1use core::time::Duration;
14use parking_lot::Mutex;
15use std::future::Future;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tokio::sync::watch;
19use tonic::service::Routes;
20use tonic::transport::Server as TonicServer;
21use tsoracle_consensus::ConsensusDriver;
22use tsoracle_core::{Allocator, Clock, SystemClock};
23#[cfg(any(test, feature = "test-fakes"))]
24use tsoracle_core::{CoreError, WindowGrant};
25use tsoracle_proto::v1::tso_service_server::TsoServiceServer;
26
27use crate::service::TsoServiceImpl;
28
29#[derive(Debug, thiserror::Error)]
30pub enum BuildError {
31 #[error("consensus_driver is required")]
32 MissingConsensusDriver,
33 #[error("invalid leader-hint metadata key: {0}")]
39 InvalidLeaderHintKey(#[from] tonic::metadata::errors::InvalidMetadataKey),
40}
41
42#[derive(Debug, thiserror::Error)]
43pub enum ServerError {
44 #[error("transport: {0}")]
45 Transport(#[from] tonic::transport::Error),
46 #[error("consensus: {0}")]
47 Consensus(#[from] tsoracle_consensus::ConsensusError),
48 #[error("core: {0}")]
49 Core(#[from] tsoracle_core::CoreError),
50 #[error("leader-watch task panicked: {payload}")]
54 WatchPanic { payload: String },
55}
56
/// Whether this node currently answers timestamp requests.
///
/// A freshly built [`Server`] starts in
/// `NotServing { leader_endpoint: None }`.
#[derive(Clone, Debug)]
pub enum ServingState {
    /// Not serving.
    ///
    /// `leader_endpoint` optionally carries the endpoint of the node believed
    /// to be leader. It is `None` when unknown.
    NotServing { leader_endpoint: Option<String> },
    /// Serving timestamp requests.
    Serving,
}
62
63pub struct ServerBuilder {
64 consensus: Option<Arc<dyn ConsensusDriver>>,
65 clock: Option<Arc<dyn Clock>>,
66 window_ahead: Duration,
67 failover_advance: Duration,
68 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
69 tls_config: Option<tonic::transport::ServerTlsConfig>,
70}
71
72impl Default for ServerBuilder {
73 fn default() -> Self {
74 ServerBuilder {
75 consensus: None,
76 clock: None,
77 window_ahead: Duration::from_secs(3),
78 failover_advance: Duration::from_secs(1),
79 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
80 tls_config: None,
81 }
82 }
83}
84
85impl ServerBuilder {
86 pub fn consensus_driver(mut self, driver: Arc<dyn ConsensusDriver>) -> Self {
87 self.consensus = Some(driver);
88 self
89 }
90 pub fn clock(mut self, clock: Arc<dyn Clock>) -> Self {
91 self.clock = Some(clock);
92 self
93 }
94 pub fn window_ahead(mut self, window_ahead: Duration) -> Self {
95 self.window_ahead = window_ahead;
96 self
97 }
98 pub fn failover_advance(mut self, failover_advance: Duration) -> Self {
99 self.failover_advance = failover_advance;
100 self
101 }
102
103 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
109 pub fn tls_config(mut self, cfg: tonic::transport::ServerTlsConfig) -> Self {
110 self.tls_config = Some(cfg);
111 self
112 }
113
114 pub fn build(self) -> Result<Server, BuildError> {
115 crate::leader_hint::validate_key()?;
116 let consensus = self.consensus.ok_or(BuildError::MissingConsensusDriver)?;
117 let clock = self.clock.unwrap_or_else(|| Arc::new(SystemClock));
118 let (state_tx, state_rx) = watch::channel(ServingState::NotServing {
119 leader_endpoint: None,
120 });
121 Ok(Server {
122 consensus,
123 clock,
124 window_ahead: self.window_ahead,
125 failover_advance: self.failover_advance,
126 allocator: Arc::new(Mutex::new(Allocator::new())),
127 state_tx,
128 state_rx,
129 extension_lock: Arc::new(tokio::sync::Mutex::new(())),
130 extension_gate: Arc::new(tokio::sync::RwLock::new(())),
131 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
132 tls_config: self.tls_config,
133 })
134 }
135}
136
137pub struct Server {
138 pub(crate) consensus: Arc<dyn ConsensusDriver>,
139 pub(crate) clock: Arc<dyn Clock>,
140 pub(crate) window_ahead: Duration,
141 pub(crate) failover_advance: Duration,
142 pub(crate) allocator: Arc<Mutex<Allocator>>,
143 pub(crate) state_tx: watch::Sender<ServingState>,
144 pub state_rx: watch::Receiver<ServingState>,
145 pub(crate) extension_lock: Arc<tokio::sync::Mutex<()>>,
151 pub(crate) extension_gate: Arc<tokio::sync::RwLock<()>>,
157 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
158 pub(crate) tls_config: Option<tonic::transport::ServerTlsConfig>,
159}
160
161impl Server {
162 pub fn builder() -> ServerBuilder {
163 ServerBuilder::default()
164 }
165
166 pub(crate) fn step_down_due_to_consensus_rejection(&self, leader_endpoint: Option<String>) {
185 self.allocator.lock().on_leadership_lost();
186 let _ = self
187 .state_tx
188 .send(ServingState::NotServing { leader_endpoint });
189 }
190}
191
192impl Server {
193 pub fn into_router(self) -> (Routes, tokio::task::JoinHandle<Result<(), ServerError>>) {
207 let server = Arc::new(self);
208
209 let watch_server = server.clone();
210 let watch_handle = tokio::spawn(async move {
211 use futures::FutureExt;
212 let outcome =
221 std::panic::AssertUnwindSafe(crate::fence::run_leader_watch(watch_server.clone()))
222 .catch_unwind()
223 .await;
224 match outcome {
225 Ok(result) => {
226 if let Err(ref _e) = result {
227 watch_server.step_down_due_to_consensus_rejection(None);
230 #[cfg(feature = "tracing")]
231 tracing::error!(error = %_e, "leader-watch terminated; serving disabled");
232 }
233 result
234 }
235 Err(panic_payload) => {
236 watch_server.step_down_due_to_consensus_rejection(None);
239 #[cfg(feature = "tracing")]
240 tracing::error!("leader-watch panicked; serving disabled");
241 std::panic::resume_unwind(panic_payload);
242 }
243 }
244 });
245
246 let service = TsoServiceImpl { server };
247 #[allow(unused_mut)]
248 let mut routes = Routes::new(TsoServiceServer::new(service));
249 #[cfg(feature = "reflection")]
250 {
251 #[expect(
252 clippy::expect_used,
253 reason = "`FILE_DESCRIPTOR_SET` is generated by `tsoracle-proto`'s `build.rs` from checked-in `.proto` sources; if it ever fails to decode, the build itself is broken. Tracked by #9."
254 )]
255 let reflection = tonic_reflection::server::Builder::configure()
256 .register_encoded_file_descriptor_set(tsoracle_proto::FILE_DESCRIPTOR_SET)
257 .build_v1()
258 .expect("FILE_DESCRIPTOR_SET emitted by build.rs is always valid");
259 routes = routes.add_service(reflection);
260 }
261 (routes, watch_handle)
262 }
263
264 pub async fn serve(self, addr: SocketAddr) -> Result<(), ServerError> {
265 self.serve_with_shutdown(addr, futures::future::pending())
266 .await
267 }
268
269 pub async fn serve_with_shutdown(
283 self,
284 addr: SocketAddr,
285 shutdown: impl Future<Output = ()> + Send + 'static,
286 ) -> Result<(), ServerError> {
287 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
288 let tls_config = self.tls_config.clone();
289
290 let (routes, mut watch_handle) = self.into_router();
291 let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();
292
293 let combined_shutdown = async move {
296 tokio::select! {
297 _ = shutdown => {}
298 _ = cancel_rx => {}
299 }
300 };
301
302 let mut tonic = TonicServer::builder();
303 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
304 if let Some(cfg) = tls_config {
305 tonic = tonic.tls_config(cfg).map_err(ServerError::Transport)?;
306 }
307 let serve = tonic
308 .add_routes(routes)
309 .serve_with_shutdown(addr, combined_shutdown);
310 tokio::pin!(serve);
311
312 tokio::select! {
313 biased;
318
319 watch_result = &mut watch_handle => {
320 let _ = cancel_tx.send(());
324 let _ = serve.await;
325 join_to_server_result(watch_result)
326 }
327 serve_result = &mut serve => {
328 watch_handle.abort();
334 serve_result?;
335 Ok(())
336 }
337 }
338 }
339
340 pub async fn serve_with_listener(
361 self,
362 listener: tokio::net::TcpListener,
363 shutdown: impl Future<Output = ()> + Send + 'static,
364 ) -> Result<(), ServerError> {
365 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
366 let tls_config = self.tls_config.clone();
367
368 let (routes, mut watch_handle) = self.into_router();
369 let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel::<()>();
370
371 let combined_shutdown = async move {
372 tokio::select! {
373 _ = shutdown => {}
374 _ = cancel_rx => {}
375 }
376 };
377
378 let incoming = tonic::transport::server::TcpIncoming::from(listener);
379
380 let mut tonic = TonicServer::builder();
381 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
382 if let Some(cfg) = tls_config {
383 tonic = tonic.tls_config(cfg).map_err(ServerError::Transport)?;
384 }
385 let serve = tonic
386 .add_routes(routes)
387 .serve_with_incoming_shutdown(incoming, combined_shutdown);
388 tokio::pin!(serve);
389
390 tokio::select! {
391 biased;
392
393 watch_result = &mut watch_handle => {
394 let _ = cancel_tx.send(());
395 let _ = serve.await;
396 join_to_server_result(watch_result)
397 }
398 serve_result = &mut serve => {
399 watch_handle.abort();
400 serve_result?;
401 Ok(())
402 }
403 }
404 }
405}
406
407fn join_to_server_result(
415 join_result: Result<Result<(), ServerError>, tokio::task::JoinError>,
416) -> Result<(), ServerError> {
417 match join_result {
418 Ok(inner) => inner,
419 Err(join_err) if join_err.is_panic() => {
420 let payload = panic_payload_to_string(join_err.into_panic());
421 Err(ServerError::WatchPanic { payload })
422 }
423 Err(_cancelled) => Ok(()),
424 }
425}
426
/// Extracts a human-readable message from a panic payload.
///
/// Handles `&'static str` payloads (literal `panic!` messages) and `String`
/// payloads (formatted messages). Any other payload type produces a fixed
/// placeholder message.
fn panic_payload_to_string(panic: Box<dyn std::any::Any + Send>) -> String {
    panic
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_string())
        .or_else(|| panic.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "watch task panicked with non-string payload".to_string())
}
436
#[cfg(any(test, feature = "test-fakes"))]
impl Server {
    /// Runs `crate::fence::run_leader_watch` directly on this server.
    ///
    /// Unlike [`Server::into_router`], this adds no panic catching and no
    /// step-down wrapper.
    #[doc(hidden)]
    pub async fn run_leader_watch_for_tests(self: Arc<Self>) -> Result<(), ServerError> {
        let watch = crate::fence::run_leader_watch(self);
        watch.await
    }

    /// Asks the allocator directly for `count` timestamps at the clock's
    /// current time, bypassing the gRPC service.
    ///
    /// The allocator lock is taken before the clock is read.
    #[doc(hidden)]
    pub fn try_grant_for_tests(&self, count: u32) -> Result<WindowGrant, CoreError> {
        let allocator = &self.allocator;
        allocator.lock().try_grant(self.clock.now_ms(), count)
    }
}
455
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn panic_payload_to_string_recovers_static_str() {
        let payload: Box<dyn std::any::Any + Send> = Box::new("watch boom");
        assert_eq!(panic_payload_to_string(payload), "watch boom");
    }

    #[test]
    fn panic_payload_to_string_recovers_owned_string() {
        let payload: Box<dyn std::any::Any + Send> = Box::new(String::from("formatted"));
        assert_eq!(panic_payload_to_string(payload), "formatted");
    }

    #[test]
    fn panic_payload_to_string_falls_back_for_other_types() {
        struct Custom;
        let payload: Box<dyn std::any::Any + Send> = Box::new(Custom);
        assert_eq!(
            panic_payload_to_string(payload),
            "watch task panicked with non-string payload",
        );
    }

    /// The builder's defaults are part of its documented behaviour.
    #[test]
    fn builder_defaults_are_stable() {
        let builder = ServerBuilder::default();
        assert_eq!(builder.window_ahead, Duration::from_secs(3));
        assert_eq!(builder.failover_advance, Duration::from_secs(1));
        assert!(builder.consensus.is_none());
        assert!(builder.clock.is_none());
    }

    #[test]
    fn builder_setters_overwrite_durations() {
        let builder = Server::builder()
            .window_ahead(Duration::from_millis(500))
            .failover_advance(Duration::from_millis(250));
        assert_eq!(builder.window_ahead, Duration::from_millis(500));
        assert_eq!(builder.failover_advance, Duration::from_millis(250));
    }

    /// `build` must refuse to produce a server without a consensus driver.
    #[test]
    fn build_without_consensus_driver_is_rejected() {
        let result = Server::builder().build();
        assert!(matches!(result, Err(BuildError::MissingConsensusDriver)));
    }

    #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
    #[test]
    fn builder_stores_tls_config() {
        use crate::test_fakes::InMemoryDriver;

        let driver = Arc::new(InMemoryDriver::new());
        let cfg = tonic::transport::ServerTlsConfig::new();
        let server = Server::builder()
            .consensus_driver(driver)
            .tls_config(cfg)
            .build()
            .expect("build with tls_config must succeed");
        assert!(
            server.tls_config.is_some(),
            "tls_config must be stored on Server"
        );
    }

    #[tokio::test]
    async fn join_to_server_result_passes_through_clean_outcome() {
        let handle = tokio::spawn(async { Ok::<(), ServerError>(()) });
        let join = handle.await;
        assert!(matches!(join_to_server_result(join), Ok(())));
    }

    #[tokio::test]
    async fn join_to_server_result_forwards_inner_error() {
        let handle = tokio::spawn(async {
            Err::<(), ServerError>(ServerError::WatchPanic {
                payload: "synthetic".into(),
            })
        });
        let join = handle.await;
        match join_to_server_result(join) {
            Err(ServerError::WatchPanic { payload }) => assert_eq!(payload, "synthetic"),
            other => panic!("expected forwarded WatchPanic, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn join_to_server_result_translates_panic_to_watch_panic() {
        let handle = tokio::spawn(async {
            panic!("intentional");
            #[allow(unreachable_code)]
            Ok::<(), ServerError>(())
        });
        let join = handle.await;
        match join_to_server_result(join) {
            Err(ServerError::WatchPanic { payload }) => assert!(payload.contains("intentional")),
            other => panic!("expected WatchPanic, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn join_to_server_result_treats_cancellation_as_clean_exit() {
        let handle: tokio::task::JoinHandle<Result<(), ServerError>> =
            tokio::spawn(async { futures::future::pending().await });
        handle.abort();
        let join = handle.await;
        assert!(matches!(join_to_server_result(join), Ok(())));
    }
}