// transmit/metrics.rs

use crate::contract;
use crate::model::MetricEvent;

use http_body_util::Full;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{body::Bytes, Request, Response};
use hyper_util::rt::TokioIo;
use log::{info, warn};
use prometheus_client::encoding::text::encode;
use prometheus_client::encoding::{EncodeLabelSet, EncodeLabelValue};
use prometheus_client::metrics::counter::Counter;
use prometheus_client::metrics::family::Family;
use prometheus_client::registry::Registry;
use serde::Deserialize;
use std::{future::Future, io, net::SocketAddr, pin::Pin, sync::Arc};
use tokio::net::TcpListener;

/// Name under which the procedure counter family is registered.
const METRIC_NAME: &str = "procedure";

/// Help text registered alongside [`METRIC_NAME`].
const METRIC_HELP_TEXT: &str = "Number of procedure calls";

22#[derive(Debug, Deserialize)]
23pub struct Config {
24    pub port: u16,
25    pub endpoint: String,
26}
27
28pub fn new(config: Config) -> (MetricClient, MetricServer) {
29    let mut registry = <Registry>::default();
30
31    let procedure_metric = Family::<ResultLabel, Counter>::default();
32
33    registry.register(METRIC_NAME, METRIC_HELP_TEXT, procedure_metric.clone());
34
35    (
36        MetricClient { procedure_metric },
37        MetricServer { config, registry },
38    )
39}
40
41pub struct MetricClient {
42    procedure_metric: Family<ResultLabel, Counter>,
43}
44
45pub struct MetricServer {
46    config: Config,
47    registry: Registry,
48}
49
50impl MetricServer {
51    pub async fn start_server(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
52        let registry = Arc::new(self.registry);
53
54        let address = SocketAddr::from(([0, 0, 0, 0], self.config.port));
55
56        // We create a TcpListener and bind it to the address.
57        let listener = TcpListener::bind(address).await?;
58
59        info!("Starting metrics server on {address}");
60
61        // We start a loop to continuously accept incoming connections.
62        loop {
63            let (stream, _) = listener.accept().await?;
64
65            // Use an adapter to access something implementing `tokio::io` traits as if they implement.
66            // `hyper::rt` IO traits.
67            let io = TokioIo::new(stream);
68
69            let registry = registry.clone();
70            // Spawn a tokio task to serve multiple connections concurrently.
71            tokio::task::spawn(async move {
72                // Finally, we bind the incoming connection to our service.
73                if let Err(err) = http1::Builder::new()
74                    // `service_fn` converts our function in a `Service`.
75                    .serve_connection(io, service_fn(make_handler(registry.clone())))
76                    .await
77                {
78                    println!("Error serving connection: {:?}", err);
79                }
80            });
81        }
82    }
83}
84
85/// make_handler returns a HTTP handler that returns the stored metrics from the registry.
86fn make_handler(
87    registry: Arc<Registry>,
88) -> impl Fn(
89    Request<hyper::body::Incoming>,
90) -> Pin<Box<dyn Future<Output = io::Result<Response<Full<Bytes>>>> + Send>> {
91    // This closure accepts a request and responds with the OpenMetrics encoding of our metrics.
92    move |_req: Request<hyper::body::Incoming>| {
93        let reg = registry.clone();
94
95        Box::pin(async move {
96            let mut buf = String::new();
97            encode(&mut buf, &reg.clone())
98                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
99                .map(|_| {
100                    Response::builder()
101                        .header(
102                            hyper::header::CONTENT_TYPE,
103                            "application/openmetrics-text; version=1.0.0; charset=utf-8",
104                        )
105                        .body(Full::new(buf.into()))
106                        .unwrap()
107                })
108        })
109    }
110}
111
112impl contract::Metrics for MetricClient {
113    fn count(&self, metric_event: MetricEvent) {
114        let metric_label = ResultLabel::from(metric_event);
115        self.procedure_metric.get_or_create(&metric_label).inc();
116    }
117}
118
119#[derive(Clone, Hash, PartialEq, Eq, EncodeLabelSet, Debug)]
120pub struct ResultLabel {
121    pub procedure: Procedure,
122    pub result: ResultStatus,
123}
124
125#[derive(Clone, Hash, PartialEq, Eq, EncodeLabelValue, Debug)]
126pub enum Procedure {
127    Scheduled,
128    Polled,
129    Transmitted,
130    ScheduleStateSaved,
131    Rescheduled,
132}
133
134impl From<MetricEvent> for ResultLabel {
135    fn from(metric_event: MetricEvent) -> ResultLabel {
136        match metric_event {
137            MetricEvent::Scheduled(success) => ResultLabel {
138                procedure: Procedure::Scheduled,
139                result: ResultStatus::from(success),
140            },
141            MetricEvent::Polled(success) => ResultLabel {
142                procedure: Procedure::Polled,
143                result: ResultStatus::from(success),
144            },
145            MetricEvent::Transmitted(success) => ResultLabel {
146                procedure: Procedure::Transmitted,
147                result: ResultStatus::from(success),
148            },
149            MetricEvent::ScheduleStateSaved(success) => ResultLabel {
150                procedure: Procedure::ScheduleStateSaved,
151                result: ResultStatus::from(success),
152            },
153            MetricEvent::Rescheduled(success) => ResultLabel {
154                procedure: Procedure::Rescheduled,
155                result: ResultStatus::from(success),
156            },
157        }
158    }
159}
160
161#[derive(Clone, Hash, PartialEq, Eq, EncodeLabelValue, Debug)]
162pub enum ResultStatus {
163    Success,
164    Failure,
165}
166
167impl From<bool> for ResultStatus {
168    fn from(value: bool) -> ResultStatus {
169        match value {
170            true => ResultStatus::Success,
171            false => ResultStatus::Failure,
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    use prometheus_client::encoding::text::encode;

    use crate::contract::Metrics;

    /// Builds the exposition text expected from a registry that holds a single
    /// procedure counter sample with the given label values and a count of 1.
    fn expected_exposition(procedure: &str, result: &str) -> String {
        format!(
            "# HELP {METRIC_NAME} {METRIC_HELP_TEXT}.\n\
             # TYPE {METRIC_NAME} counter\n\
             {METRIC_NAME}_total{{procedure=\"{procedure}\",result=\"{result}\"}} 1\n\
             # EOF\n"
        )
    }

    #[test]
    fn test_metric_increment() {
        let config = Config {
            port: 0,
            endpoint: "/".to_string(),
        };
        let (client, server) = new(config);
        client.count(MetricEvent::Polled(false));

        let mut buffer = String::new();
        encode(&mut buffer, &server.registry).unwrap();

        assert_eq!(expected_exposition("Polled", "Failure"), buffer);
    }

    #[tokio::test]
    async fn test_metric_server() -> Result<(), Box<dyn std::error::Error>> {
        let port = 8083;
        let config = Config {
            port,
            endpoint: "".to_string(),
        };

        let (client, server) = new(config);

        let server_task = tokio::spawn(async move {
            server
                .start_server()
                .await
                .expect("prometheus server must start");
        });

        // `#[tokio::test]` runs on a single-threaded runtime by default. A
        // blocking `std::thread::sleep` would also stall the spawned server
        // task. Awaiting an async sleep yields to the scheduler, so the server
        // can bind its listener before the request below is sent.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // Cause an arbitrary metric mutation.
        client.count(MetricEvent::Scheduled(true));

        // The server binds 0.0.0.0 (IPv4 only). Using the IPv4 loopback avoids
        // `localhost` resolving to `::1` first.
        let address = format!("http://127.0.0.1:{port}");
        let response_body = reqwest::get(address).await?.text().await?;

        server_task.abort();

        assert_eq!(expected_exposition("Scheduled", "Success"), response_body);

        Ok(())
    }
}