Skip to main content

xapi_rs/lrs/resources/
stats.rs

1// SPDX-License-Identifier: GPL-3.0-or-later
2
3//! Track some basic statistics per route. The metrics we collect here are:
4//!
5//! * Number of requests,
6//! * Minimum,
//! * Average, and
//! * Maximum durations (in nanoseconds) of servicing a request.
9
10use crate::MyError;
11use dashmap::DashMap;
12use rocket::{
13    Orbit, Rocket, Route,
14    fairing::{Fairing, Info, Kind},
15    get,
16    http::Method,
17    routes,
18    serde::json::Json,
19};
20use serde::Serialize;
21use std::sync::{
22    Arc, OnceLock,
23    atomic::{AtomicU64, Ordering},
24};
25use tracing::{error, info};
26
27/// How we identify a route.
28#[derive(Debug, Eq, Hash, PartialEq)]
29struct RouteAttributes {
30    method: Method,
31    path: String,
32    mime: String,
33    rank: isize,
34}
35
36impl From<&Route> for RouteAttributes {
37    fn from(route: &Route) -> RouteAttributes {
38        let mime = if let Some(z_format) = route.format.as_ref() {
39            z_format.to_string()
40        } else {
41            "N/A".to_owned()
42        };
43        RouteAttributes {
44            method: route.method,
45            path: route.uri.origin.path().to_string(),
46            mime,
47            rank: route.rank,
48        }
49    }
50}
51
/// Statistics tracked for a single route. All durations are in nanoseconds.
#[derive(Debug)]
struct RouteStats {
    /// Total number of requests serviced.
    count: AtomicU64,
    /// Shortest request duration seen so far (`u64::MAX` until the first request).
    min: AtomicU64,
    /// Running average request duration.
    avg: AtomicU64,
    /// Longest request duration seen so far.
    max: AtomicU64,
}

impl Default for RouteStats {
    /// Empty statistics: every counter starts at zero except `min`, which
    /// starts at `u64::MAX` so the first recorded duration replaces it.
    fn default() -> Self {
        let zero = || AtomicU64::new(0);
        RouteStats {
            count: zero(),
            min: AtomicU64::new(u64::MAX),
            avg: zero(),
            max: zero(),
        }
    }
}
73
74static ENDPOINTS: OnceLock<Arc<DashMap<RouteAttributes, RouteStats>>> = OnceLock::new();
75fn endpoints() -> Arc<DashMap<RouteAttributes, RouteStats>> {
76    ENDPOINTS.get_or_init(|| Arc::new(DashMap::new())).clone()
77}
78
/// Fairing that registers every mounted route for statistics collection at
/// liftoff and logs the collected statistics at shutdown.
pub(crate) struct StatsFairing;
81
82#[rocket::async_trait]
83impl Fairing for StatsFairing {
84    fn info(&self) -> Info {
85        Info {
86            name: "Routes Statistics",
87            kind: Kind::Liftoff | Kind::Shutdown,
88        }
89    }
90
91    /// Populate the endpoints map from known registered routes.
92    async fn on_liftoff(&self, r: &Rocket<Orbit>) {
93        for route in r.routes() {
94            let key = RouteAttributes::from(route);
95            endpoints().insert(key, RouteStats::default());
96        }
97    }
98
99    /// Output @info server stats collected during the run.
100    async fn on_shutdown(&self, _: &Rocket<Orbit>) {
101        let stats = endpoints();
102        let (total_count, total_avg): (u64, u64) = stats
103            .iter()
104            .filter(|e| e.count.load(Ordering::Relaxed) > 0)
105            .fold((0, 0), |(sum_count, sum_avg), e| {
106                (
107                    sum_count + e.count.load(Ordering::Relaxed),
108                    sum_avg + e.avg.load(Ordering::Relaxed),
109                )
110            });
111        let average_duration = total_avg.checked_div(total_count).unwrap_or(0);
112        info!("LaRS stats\n{:?}", stats);
113        info!(
114            "*** Total calls = {}; Average duration = {} ns",
115            total_count, average_duration
116        );
117    }
118}
119
120// Update stats for given route and request duration.
121pub(crate) fn update_stats(route: &Route, duration: u64) {
122    let key = RouteAttributes::from(route);
123    let tmp = endpoints();
124    let tmp = tmp.get_mut(&key);
125    match tmp {
126        Some(endpoint) => {
127            endpoint.min.fetch_min(duration, Ordering::Relaxed);
128            endpoint.max.fetch_max(duration, Ordering::Relaxed);
129            let old_count = endpoint.count.fetch_add(1, Ordering::Relaxed);
130            let old_avg = endpoint.avg.fetch_add(0, Ordering::Relaxed);
131            let new_avg = (old_count * old_avg + duration) / (old_count + 1);
132            endpoint.avg.store(new_avg, Ordering::Relaxed);
133        }
134        _ => error!("Failed finding stats for {}", route),
135    }
136}
137
138#[doc(hidden)]
139pub fn routes() -> Vec<rocket::Route> {
140    routes![stats]
141}
142
143#[derive(Debug, Serialize)]
144struct StatsRecord {
145    method: String,
146    path: String,
147    mime: String,
148    rank: isize,
149    count: u64,
150    min: u64,
151    avg: u64,
152    max: u64,
153}
154
155#[get("/")]
156async fn stats() -> Result<Json<Vec<StatsRecord>>, MyError> {
157    let result = endpoints()
158        .iter()
159        .filter(|x| x.count.load(Ordering::Relaxed) > 0)
160        .map(|x| {
161            let (k, v) = x.pair();
162            StatsRecord {
163                method: k.method.to_string(),
164                path: k.path.clone(),
165                mime: k.mime.clone(),
166                rank: k.rank,
167                count: v.count.load(Ordering::Relaxed),
168                min: v.min.load(Ordering::Relaxed),
169                avg: v.avg.load(Ordering::Relaxed),
170                max: v.max.load(Ordering::Relaxed),
171            }
172        })
173        .collect();
174    Ok(Json(result))
175}