use crate::service::routing::Route;
use crate::service::{Layer, Service};
use conjure_http::server::EndpointMetadata;
use futures_util::ready;
use http::{HeaderMap, Request, Response};
use http_body::Body;
use pin_project::{pin_project, pinned_drop};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::time::Instant;
use witchcraft_metrics::{Meter, MetricId, MetricRegistry, Timer};
/// Per-endpoint response metrics.
///
/// Both metrics are tagged with the endpoint's `service-name` and `endpoint` name.
pub struct EndpointMetrics {
    /// The `server.response` timer.
    response: Arc<Timer>,
    /// The `server.response.error` meter.
    response_error: Arc<Meter>,
}

impl EndpointMetrics {
    /// Obtains the `server.response` timer and `server.response.error` meter for `endpoint`
    /// from `metrics`.
    // NOTE(review): `dead_code` is allowed here without an explanation; confirm which build
    // configurations leave this constructor unused.
    #[allow(dead_code)]
    pub fn new(metrics: &MetricRegistry, endpoint: &dyn EndpointMetadata) -> Self {
        // Both metric IDs carry the same pair of endpoint-identifying tags.
        let tagged_id = |metric_name: &'static str| {
            MetricId::new(metric_name)
                .with_tag("service-name", endpoint.service_name().to_string())
                .with_tag("endpoint", endpoint.name().to_string())
        };

        let response = metrics.timer(tagged_id("server.response"));
        let response_error = metrics.meter(tagged_id("server.response.error"));

        EndpointMetrics {
            response,
            response_error,
        }
    }
}
/// A [`Layer`] that wraps a service in an [`EndpointMetricsService`].
pub struct EndpointMetricsLayer;

impl<S> Layer<S> for EndpointMetricsLayer {
    type Service = EndpointMetricsService<S>;

    fn layer(self, service: S) -> Self::Service {
        EndpointMetricsService { inner: service }
    }
}
/// A service that attaches per-endpoint response metrics to requests whose [`Route`] is
/// resolved to an endpoint with metrics.
pub struct EndpointMetricsService<S> {
    inner: S,
}

impl<S, B1, B2> Service<Request<B1>> for EndpointMetricsService<S>
where
    S: Service<Request<B1>, Response = Response<B2>>,
{
    type Response = Response<EndpointMetricsBody<B2>>;
    type Future = EndpointMetricsFuture<S::Future>;

    /// # Panics
    ///
    /// Panics if the request extensions do not contain a [`Route`].
    fn call(&self, req: Request<B1>) -> Self::Future {
        let route = req
            .extensions()
            .get::<Route>()
            .expect("Route missing from request extensions");

        // Only resolved routes can supply metrics; any other route records nothing.
        let metrics = if let Route::Resolved(endpoint) = route {
            endpoint.metrics()
        } else {
            None
        };

        let start_time = Instant::now();
        // Clone the shared handles out before `req` is moved into the inner service.
        let (response, response_error) = match metrics {
            Some(m) => (Some(m.response.clone()), Some(m.response_error.clone())),
            None => (None, None),
        };

        EndpointMetricsFuture {
            inner: self.inner.call(req),
            start_time,
            response,
            response_error,
        }
    }
}
/// The future returned by [`EndpointMetricsService`].
///
/// When the inner future completes, it marks the error meter for 5xx responses. It also wraps
/// the response body in an [`EndpointMetricsBody`] that carries the response timer.
#[pin_project]
pub struct EndpointMetricsFuture<F> {
    #[pin]
    inner: F,
    start_time: Instant,
    // Moved into the response body once the inner future resolves.
    response: Option<Arc<Timer>>,
    response_error: Option<Arc<Meter>>,
}

impl<F, B> Future for EndpointMetricsFuture<F>
where
    F: Future<Output = Response<B>>,
{
    type Output = Response<EndpointMetricsBody<B>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let response = ready!(this.inner.poll(cx));

        let is_server_error = response.status().is_server_error();
        if let (true, Some(meter)) = (is_server_error, this.response_error.as_ref()) {
            meter.mark(1);
        }

        // Hand the timer to the body; the body updates it when it is dropped.
        let timer = this.response.take();
        let start_time = *this.start_time;
        Poll::Ready(response.map(move |inner| EndpointMetricsBody {
            inner,
            start_time,
            response: timer,
        }))
    }
}
/// A response body wrapper that, when dropped, updates the endpoint's response timer with the
/// time elapsed since `start_time`.
///
/// All [`Body`] methods are forwarded unchanged to the wrapped body.
#[pin_project(PinnedDrop)]
pub struct EndpointMetricsBody<B> {
    #[pin]
    inner: B,
    start_time: Instant,
    response: Option<Arc<Timer>>,
}

#[pinned_drop]
impl<B> PinnedDrop for EndpointMetricsBody<B> {
    fn drop(self: Pin<&mut Self>) {
        if let Some(timer) = self.response.as_deref() {
            timer.update(self.start_time.elapsed());
        }
    }
}

impl<B> Body for EndpointMetricsBody<B>
where
    B: Body,
{
    type Data = B::Data;
    type Error = B::Error;

    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        let this = self.project();
        this.inner.poll_data(cx)
    }

    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
        let this = self.project();
        this.inner.poll_trailers(cx)
    }

    fn is_end_stream(&self) -> bool {
        B::is_end_stream(&self.inner)
    }

    fn size_hint(&self) -> http_body::SizeHint {
        B::size_hint(&self.inner)
    }
}