1use std::collections::{HashMap, HashSet};
4use std::pin::Pin;
5use std::rc::Rc;
6use std::sync::{Arc, RwLock, Weak};
7use std::task::{Context, Poll};
8
9use actix_service::{Service, Transform};
10use futures::future::{Future, FutureExt, Ready, TryFutureExt, ok as fut_ok};
11use log::{debug, warn};
12
13use actix_web::body::MessageBody;
14use actix_web::error::Error;
15use actix_web::http::StatusCode;
16use actix_web::{
17 HttpResponse,
18 dev::{ServiceRequest, ServiceResponse},
19};
20
21#[cfg(not(feature = "swagger"))]
22use actix_web::web;
23
24#[cfg(feature = "swagger")]
25use paperclip::actix::web;
26
27use serde::Serialize;
28
29#[cfg(feature = "prometheus")]
30pub use super::prometheus::AsPrometheus;
31
/// Clonable handle to the request counters maintained by the stats middleware.
///
/// The counters live behind `Arc<RwLock<..>>`, so cloning the handle clones the `Arc` and every
/// clone refers to the same [`BaseStatsInner`].
#[derive(Clone)]
pub struct BaseStats(pub(super) Arc<RwLock<BaseStatsInner>>);
35
/// Request counters guarded by [`BaseStats`]; serialized as the `base` part of [`StatsOutput`].
#[derive(Clone, Serialize)]
pub struct BaseStatsInner {
    /// Number of counted requests the middleware has passed to the wrapped service.
    pub(super) request_started: usize,
    /// Number of counted requests for which the wrapped service has produced a result.
    pub(super) request_finished: usize,
    /// Number of finished requests, keyed by the numeric HTTP status code of their result.
    pub(super) status_codes: HashMap<u16, usize>,
}
43
44impl Default for BaseStats {
45 fn default() -> Self {
46 Self(Arc::new(RwLock::new(BaseStatsInner {
47 request_started: 0,
48 request_finished: 0,
49 status_codes: HashMap::new(),
50 })))
51 }
52}
53
/// Middleware factory: wraps a service in [`StatsMiddleware`] and hands it the shared
/// [`StatsConfig`].
pub struct StatsWrapper(Rc<StatsConfig>);
56
/// Settings shared between a [`StatsWrapper`] and the middlewares it creates.
struct StatsConfig {
    /// Request paths that are left out of the statistics; compared for exact equality with the
    /// request path.
    excludes: HashSet<String>,
}
61
62impl StatsWrapper {
63 pub fn new(excludes: HashSet<String>) -> Self {
64 Self(Rc::new(StatsConfig { excludes }))
65 }
66}
67
68impl Default for StatsWrapper {
69 fn default() -> Self {
70 let mut excludes = HashSet::with_capacity(2);
71 excludes.insert("/_healthcheck".to_string());
72 excludes.insert("/_ready".to_string());
73 excludes.insert("/_stats".to_string());
74 #[cfg(feature = "prometheus")]
75 excludes.insert("/_prometheus".to_string());
76 Self::new(excludes)
77 }
78}
79
80impl<S, B> Transform<S, ServiceRequest> for StatsWrapper
81where
82 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
83 S::Future: 'static,
84 B: MessageBody,
85{
86 type Response = ServiceResponse<B>;
88 type Error = Error;
89 type InitError = ();
90 type Transform = StatsMiddleware<S>;
91 type Future = Ready<Result<Self::Transform, Self::InitError>>;
92
93 fn new_transform(&self, service: S) -> Self::Future {
94 fut_ok(StatsMiddleware {
95 service,
96 config: self.0.clone(),
97 })
98 }
99}
100
/// Middleware created by [`StatsWrapper`]; updates [`BaseStats`] around each call to `service`.
pub struct StatsMiddleware<S> {
    /// The wrapped service that actually handles the request.
    service: S,
    /// Exclusion settings shared with the wrapper that created this middleware.
    config: Rc<StatsConfig>,
}
106
107impl<S, B> Service<ServiceRequest> for StatsMiddleware<S>
108where
109 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
110 S::Future: 'static,
111 B: MessageBody,
112{
113 type Response = ServiceResponse<B>;
115 type Error = Error;
116 #[allow(clippy::type_complexity)]
117 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
118
119 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 self.service.poll_ready(cx)
121 }
122
123 fn call(&self, req: ServiceRequest) -> Self::Future {
124 let count_it = !self.config.excludes.contains(req.path());
125
126 let stats_arc_for_request = req.app_data::<web::Data<BaseStats>>();
128
129 if count_it
130 && let Some(stats_arc) = stats_arc_for_request
131 && let Ok(mut stats) = stats_arc.0.write()
132 {
133 stats.request_started += 1;
134 }
135
136 let stats_arc_for_response = stats_arc_for_request.map(|bs| Arc::downgrade(&bs.0));
139
140 let fut = self.service.call(req);
141
142 Box::pin(async move {
143 let res = fut.await;
144
145 let status_code = match &res {
146 Ok(res) => {
147 if let Some(error) = res.response().error()
148 && res.response().head().status != StatusCode::INTERNAL_SERVER_ERROR
149 {
150 debug!("Error in response: {error:?}");
151 }
152 res.status()
153 }
154 Err(err) => err.error_response().status(),
155 };
156
157 if count_it {
158 if let Some(stats_arc) = stats_arc_for_response.and_then(|wbs| Weak::upgrade(&wbs))
160 && let Ok(mut stats) = stats_arc.write()
161 {
162 stats.request_finished += 1;
163 let left = stats.request_started - stats.request_finished;
164 if left > 1 {
165 warn!("Number of unfinished requests: {left}");
166 }
167 *stats.status_codes.entry(status_code.as_u16()).or_insert(0) += 1;
168 }
169 }
170
171 res
172 })
173 }
174}
175
/// Liveness handler: always completes immediately with an empty body.
pub async fn default_healthcheck_handler() -> &'static str {
    const EMPTY_BODY: &str = "";
    EMPTY_BODY
}
180
181pub async fn default_readiness_handler<S, D>(
183 service_data: web::Data<S>,
184) -> Result<HttpResponse, Error>
185where
186 D: AppDataWrapper,
187 S: StatsPresenter<D>,
188{
189 let fut_res = service_data.is_ready().map(|result| match result {
190 Err(error) => HttpResponse::build(StatusCode::INTERNAL_SERVER_ERROR)
191 .body(format!("Can't check readiness: {error}")),
192 Ok(true) => HttpResponse::build(StatusCode::OK).body("OK".to_string()),
193 Ok(false) => {
194 HttpResponse::build(StatusCode::SERVICE_UNAVAILABLE).body("Not ready yet".to_string())
195 }
196 });
197 Ok(fut_res.await)
198}
199
200pub async fn default_stats_handler<S, D>(
202 base_data: web::Data<BaseStats>,
203 service_data: web::Data<S>,
204) -> Result<HttpResponse, Error>
205where
206 D: AppDataWrapper,
207 S: StatsPresenter<D>,
208{
209 let fut_res = service_data.get_stats().and_then(move |service_stats| {
210 if let Ok(base_stats) = base_data.0.read() {
211 #[allow(clippy::unit_arg)]
212 let output = StatsOutput {
213 base: base_stats.clone(),
214 service: Some(service_stats),
215 };
216
217 fut_ok(
218 HttpResponse::build(StatusCode::OK)
219 .content_type("application/json")
220 .body(serde_json::to_string(&output).unwrap()),
221 )
222 } else {
223 fut_ok(
224 HttpResponse::build(StatusCode::INTERNAL_SERVER_ERROR)
225 .body("Can't acquire stats (1)".to_string()),
226 )
227 }
228 });
229
230 fut_res.await
231}
232
/// Serializable stats payload: the middleware counters plus optional service statistics.
#[derive(Serialize)]
pub struct StatsOutput<D: Serialize> {
    /// Copy of the request counters.
    pub(super) base: BaseStatsInner,

    /// Service-specific statistics; the field is left out of the output when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) service: Option<D>,
}
240
241pub trait StatsPresenter<D: AppDataWrapper> {
292 fn is_ready(&self) -> Pin<Box<dyn Future<Output = Result<bool, Error>>>>;
293 fn get_stats(&self) -> Pin<Box<dyn Future<Output = Result<D, Error>>>>;
294
295 #[cfg(feature = "prometheus")]
296 fn get_prometheus(&self) -> Pin<Box<dyn Future<Output = Result<Vec<String>, Error>>>> {
297 let fut = self
298 .get_stats()
299 .map(|stats_res| stats_res.map(|stats| stats.as_prometheus()));
300 Box::pin(fut)
301 }
302}
303
/// Bound for service statistics types when the `prometheus` feature is enabled: serializable,
/// convertible via [`AsPrometheus`], and `'static`.
#[cfg(feature = "prometheus")]
pub trait AppDataWrapper: Serialize + AsPrometheus + 'static {}
/// Bound for service statistics types when the `prometheus` feature is disabled: serializable.
#[cfg(not(feature = "prometheus"))]
pub trait AppDataWrapper: Serialize {}

// Blanket impls: every type that meets the bounds is an `AppDataWrapper` automatically, so
// the trait never has to be implemented by hand.
#[cfg(feature = "prometheus")]
impl<T> AppDataWrapper for T where T: Serialize + AsPrometheus + 'static {}

#[cfg(not(feature = "prometheus"))]
impl<T> AppDataWrapper for T where T: Serialize {}
314
315