// service_probe/lib.rs
// SPDX-FileCopyrightText: OpenTalk Team <mail@opentalk.eu>
//
// SPDX-License-Identifier: MIT OR Apache-2.0

//! # Service probe
//!
//! This crate provides an easy way to start an HTTP server that can be used for
//! making the status of a service transparent to observers. The main use case is
//! to communicate information about the health status of a service in containerized
//! environments.
//!
//! Tasks and synchronization throughout this crate use [`tokio`]
//! functionality, so the runtime must be present and running when the functions
//! of this crate are called.
#![deny(
    bad_style,
    missing_debug_implementations,
    missing_docs,
    overflowing_literals,
    patterns_in_fns_without_body,
    trivial_casts,
    trivial_numeric_casts,
    unsafe_code,
    unused,
    unused_extern_crates,
    unused_import_braces,
    unused_qualifications,
    unused_results
)]

use std::{convert::Infallible, net::IpAddr, time::Duration};

use http_body_util::Full;
use hyper::{server::conn::http1, service::service_fn, Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use log::{debug, error, info};
use snafu::{ResultExt as _, Snafu};
use tokio::{
    net::{TcpListener, TcpStream},
    sync::{oneshot, RwLock},
    task::JoinHandle,
};

/// Bookkeeping for a running probe server task.
///
/// A value of this type is stored in `PROBE_TASK_HANDLE` by `start_probe` and
/// taken out again by `stop_probe`.
struct ProbeTaskHandle {
    /// Sending `()` on this channel asks the server task to leave its accept loop.
    shutdown_sender: oneshot::Sender<()>,
    /// Handle of the spawned server task; awaited (with a timeout) during shutdown.
    join_handle: JoinHandle<()>,
}

// Process-wide service state reported by the probe. It starts out as
// `ServiceState::Up` and is guarded by a blocking `std::sync::RwLock`, so it can
// be read and written from synchronous code.
static SERVICE_STATE: std::sync::RwLock<ServiceState> = std::sync::RwLock::new(ServiceState::Up);
// Handle of the currently running probe server task, or `None` when no probe is
// running. This is the async `tokio::sync::RwLock`, acquired with `.await`.
static PROBE_TASK_HANDLE: RwLock<Option<ProbeTaskHandle>> = RwLock::const_new(None);

/// The grace period given to the probe for shutting itself down.
///
/// [`stop_probe`] waits at most this long for the server task to finish.
pub const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_millis(500);

/// The state of a service.
///
/// The variants are declared in lifecycle order, so the derived ordering
/// yields `ServiceState::Up < ServiceState::Ready`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ServiceState {
    /// The service is starting up.
    ///
    /// This is the state the crate reports until [`set_service_state`] is called.
    Up,

    /// The service is started and ready to process requests.
    Ready,
}

impl ServiceState {
    /// Get the [`str`] representation of the [`ServiceState`].
    pub const fn as_str(&self) -> &'static str {
        match self {
            ServiceState::Up => "UP",
            ServiceState::Ready => "READY",
        }
    }
}

/// Formats the state as the string returned by [`ServiceState::as_str`].
impl std::fmt::Display for ServiceState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: width and padding flags of the caller are not applied.
        f.write_str(self.as_str())
    }
}

/// The error that can happen during startup of the service probe.
///
/// Returned by [`start_probe`].
// NOTE(review): no explicit `#[snafu(display(...))]` attributes are present, so the
// derived `Display` text presumably comes from the variant doc comments below —
// confirm against the snafu version in use before rewording them.
#[derive(Debug, Snafu)]
pub enum ProbeStartError {
    /// The service probe has been started already.
    AlreadyStarted,

    /// The socket cannot be used for providing the service probe.
    SocketUnavailable {
        /// The source error
        source: std::io::Error,
    },
}

/// Update the service state reported by the probe.
///
/// Once this function returns, requests to the probe endpoint observe the new
/// state. Setting the state it already has is a no-op and is not logged.
pub fn set_service_state(state: ServiceState) {
    let mut current = SERVICE_STATE
        .write()
        .expect("rwlock poisoning should be impossible with the implemented control flow");

    // Nothing to do (and nothing to log) when the state is unchanged.
    if *current == state {
        return;
    }

    debug!("Service state change: {} to {}.", *current, state);
    *current = state;
}

/// Read the current service state.
///
/// The returned value is what the probe endpoint reports for the service.
pub fn get_service_state() -> ServiceState {
    let current = SERVICE_STATE
        .read()
        .expect("rwlock poisoning should be impossible with the implemented control flow");
    // `ServiceState` is `Copy`, so the value is copied out before the guard is released.
    *current
}

/// Start the probe HTTP service.
///
/// This opens a HTTP v1 server on the selected address and port which will
/// serve the state in `GET` requests to `/health` (and to `/`).
///
/// Passing `0` as `port` lets the operating system pick a free port; the log
/// message reports the address the listener was actually bound to.
///
/// # Errors
///
/// - [`ProbeStartError::AlreadyStarted`] if a probe is already running.
/// - [`ProbeStartError::SocketUnavailable`] if binding the listening socket fails.
pub async fn start_probe<A>(address: A, port: u16) -> Result<(), ProbeStartError>
where
    A: Into<IpAddr>,
{
    // The write lock is held for the whole startup so that concurrent callers
    // cannot both pass the "already started" check.
    let mut probe_task_handle = PROBE_TASK_HANDLE.write().await;

    if probe_task_handle.is_some() {
        return Err(ProbeStartError::AlreadyStarted);
    }

    let ip_address: IpAddr = address.into();

    let listener = TcpListener::bind((ip_address, port))
        .await
        .context(SocketUnavailableSnafu)?;

    // Only announce the endpoint once the socket is really bound. Prefer the
    // address reported by the listener (correct for port 0) and fall back to the
    // requested one. The socket address formatting brackets IPv6 addresses.
    let bound_address = listener
        .local_addr()
        .unwrap_or_else(|_| (ip_address, port).into());
    let state = get_service_state();
    info!("Service readiness probe listening on http://{bound_address}/ with initial state {state}");

    let (shutdown_sender, shutdown_receiver) = oneshot::channel();
    let join_handle = tokio::task::spawn(run_probe_server(listener, shutdown_receiver));

    *probe_task_handle = Some(ProbeTaskHandle {
        shutdown_sender,
        join_handle,
    });

    Ok(())
}

/// Stop the probe HTTP service.
///
/// There is a grace period defined as [`SHUTDOWN_GRACE_PERIOD`]. If the
/// grace period is exceeded, no further action will be taken, but an error will
/// be logged and this function returns.
pub async fn stop_probe() {
    let Some(ProbeTaskHandle {
        shutdown_sender,
        join_handle,
    }) = PROBE_TASK_HANDLE.write().await.take()
    else {
        return;
    };

    let _ = shutdown_sender.send(());

    debug!("Shutting down service readiness probe");

    if let Err(_elapsed) = tokio::time::timeout(SHUTDOWN_GRACE_PERIOD, join_handle).await {
        error!("Error shutting down the service readiness probe");
    }
}

async fn run_probe_server(listener: TcpListener, mut shutdown_receiver: oneshot::Receiver<()>) {
    loop {
        tokio::select! {
            accept = listener.accept() => {
                match accept {
                    Ok((stream, _addr)) => {
                        _ = tokio::spawn(handle_accept(stream));
                    }
                    Err(e) => {
                        error!("Error accepting connection for service readiness probe: {e:?}");
                    }
                }
            }
            _ = &mut shutdown_receiver => {
                return;
            }
        }
    }
}

/// Serve a single accepted TCP connection as HTTP/1.
///
/// Requests on the connection are answered by `handle_request`. A connection
/// error is logged and otherwise ignored.
async fn handle_accept(stream: TcpStream) {
    let connection =
        http1::Builder::new().serve_connection(TokioIo::new(stream), service_fn(handle_request));

    if let Err(e) = connection.await {
        error!("Error serving connection for service readiness probe: {e:?}");
    }
}

/// Answer a single probe request.
///
/// - `GET` on `/`, `/health` or `/health/` (or an empty path) returns `200 OK`
///   with the current service state as body.
/// - `GET` on any other path returns `404 Not Found`.
/// - `HEAD` returns the same status code as the corresponding `GET`, with an
///   empty body.
/// - Every other method returns `405 Method Not Allowed`.
///
/// This function never fails; problems are expressed through the status code.
async fn handle_request(
    req: Request<hyper::body::Incoming>,
) -> Result<Response<Full<&'static [u8]>>, Infallible> {
    let is_probe_path = matches!(req.uri().path(), "" | "/" | "/health" | "/health/");

    let (status_code, body) = match *req.method() {
        Method::GET if is_probe_path => (StatusCode::OK, get_service_state().as_str()),
        Method::GET => (StatusCode::NOT_FOUND, "Not found"),
        // HEAD mirrors the status of GET (including 404 for unknown paths) but
        // never carries a body.
        Method::HEAD if is_probe_path => (StatusCode::OK, ""),
        Method::HEAD => (StatusCode::NOT_FOUND, ""),
        _ => (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed"),
    };

    let mut response = Response::new(Full::new(body.as_bytes()));
    *response.status_mut() = status_code;
    Ok(response)
}