#![deny(
bad_style,
missing_debug_implementations,
missing_docs,
overflowing_literals,
patterns_in_fns_without_body,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unused,
unused_extern_crates,
unused_import_braces,
unused_qualifications,
unused_results
)]
use std::{convert::Infallible, net::IpAddr, time::Duration};
use http_body_util::Full;
use hyper::{server::conn::http1, service::service_fn, Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use log::{debug, error, info};
use snafu::{ResultExt as _, Snafu};
use tokio::{
net::{TcpListener, TcpStream},
sync::{oneshot, RwLock},
task::JoinHandle,
};
struct ProbeTaskHandle {
shutdown_sender: oneshot::Sender<()>,
join_handle: JoinHandle<()>,
}
static SERVICE_STATE: std::sync::RwLock<ServiceState> = std::sync::RwLock::new(ServiceState::Up);
static PROBE_TASK_HANDLE: RwLock<Option<ProbeTaskHandle>> = RwLock::const_new(None);
pub const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_millis(500);
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ServiceState {
Up,
Ready,
}
impl ServiceState {
pub const fn as_str(&self) -> &'static str {
match self {
ServiceState::Up => "UP",
ServiceState::Ready => "READY",
}
}
}
impl std::fmt::Display for ServiceState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}
#[derive(Debug, Snafu)]
pub enum ProbeStartError {
AlreadyStarted,
SocketUnavailable {
source: std::io::Error,
},
}
pub fn set_service_state(state: ServiceState) {
let mut state_lock = SERVICE_STATE
.write()
.expect("rwlock poisoning should be impossible with the implemented control flow");
if state != *state_lock {
debug!("Service state change: {} to {}.", *state_lock, state);
*state_lock = state;
}
}
pub fn get_service_state() -> ServiceState {
*SERVICE_STATE
.read()
.expect("rwlock poisoning should be impossible with the implemented control flow")
}
pub async fn start_probe<A>(address: A, port: u16) -> Result<(), ProbeStartError>
where
A: Into<IpAddr>,
{
let mut probe_task_handle = PROBE_TASK_HANDLE.write().await;
if probe_task_handle.is_some() {
return Err(ProbeStartError::AlreadyStarted);
}
let (shutdown_sender, shutdown_receiver) = oneshot::channel();
let ip_address: IpAddr = address.into();
let state = get_service_state();
info!("Service readiness probe listening on http://{ip_address}:{port}/ with initial state {state}");
let listener = TcpListener::bind((ip_address, port))
.await
.context(SocketUnavailableSnafu)?;
let join_handle = tokio::task::spawn(run_probe_server(listener, shutdown_receiver));
*probe_task_handle = Some(ProbeTaskHandle {
shutdown_sender,
join_handle,
});
Ok(())
}
pub async fn stop_probe() {
let Some(ProbeTaskHandle {
shutdown_sender,
join_handle,
}) = PROBE_TASK_HANDLE.write().await.take()
else {
return;
};
let _ = shutdown_sender.send(());
debug!("Shutting down service readiness probe");
if let Err(_elapsed) = tokio::time::timeout(SHUTDOWN_GRACE_PERIOD, join_handle).await {
error!("Error shutting down the service readiness probe");
}
}
async fn run_probe_server(listener: TcpListener, mut shutdown_receiver: oneshot::Receiver<()>) {
loop {
tokio::select! {
accept = listener.accept() => {
match accept {
Ok((stream, _addr)) => {
_ = tokio::spawn(handle_accept(stream));
}
Err(e) => {
error!("Error accepting connection for service readiness probe: {e:?}");
}
}
}
_ = &mut shutdown_receiver => {
return;
}
}
}
}
async fn handle_accept(stream: TcpStream) {
let io = TokioIo::new(stream);
let service = service_fn(handle_request);
if let Err(e) = http1::Builder::new().serve_connection(io, service).await {
error!("Error serving connection for service readiness probe: {e:?}");
}
}
async fn handle_request(
req: Request<hyper::body::Incoming>,
) -> Result<Response<Full<&'static [u8]>>, Infallible> {
let (status_code, body) = match *req.method() {
Method::GET => {
let path = req.uri().path();
if ["", "/", "/health", "/health/"].contains(&path) {
let state = get_service_state().as_str();
(StatusCode::OK, state)
} else {
(StatusCode::NOT_FOUND, "Not found")
}
}
Method::HEAD => (StatusCode::OK, ""),
_ => (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed"),
};
let mut response = Response::new(Full::new(body.as_bytes()));
*response.status_mut() = status_code;
Ok(response)
}