// tower_request_guard/service.rs
service.rs1use crate::body::{check_content_length, is_bodyless_method};
2use crate::content_type::matches_content_type;
3use crate::guard::RequestGuard;
4use crate::headers::find_missing_header;
5use crate::response::violation_response;
6use crate::route::RouteGuardConfig;
7use crate::violation::{OnViolation, Violation, ViolationAction};
8use http::{Request, Response};
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use tower_service::Service;
14
15pub struct RequestGuardService<S> {
17 pub(crate) inner: S,
18 pub(crate) guard: Arc<RequestGuard>,
19}
20
21impl<S: Clone> Clone for RequestGuardService<S> {
22 fn clone(&self) -> Self {
23 Self {
24 inner: self.inner.clone(),
25 guard: self.guard.clone(),
26 }
27 }
28}
29
30impl<S, B, ResBody> Service<Request<B>> for RequestGuardService<S>
31where
32 S: Service<Request<B>, Response = Response<ResBody>> + Clone + Send + 'static,
33 S::Future: Send,
34 S::Error: Send,
35 B: Send + 'static,
36 ResBody: From<String> + Send,
37{
38 type Response = Response<ResBody>;
39 type Error = S::Error;
40 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
41
42 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
43 self.inner.poll_ready(cx)
44 }
45
46 fn call(&mut self, req: Request<B>) -> Self::Future {
47 let guard = self.guard.clone();
48 let mut inner = self.inner.clone();
49 std::mem::swap(&mut self.inner, &mut inner);
50
51 Box::pin(async move {
52 let effective = match req.extensions().get::<RouteGuardConfig>() {
54 Some(route_config) => route_config.merge_with(&guard.config),
55 None => guard.config.clone(),
56 };
57
58 let is_bodyless = is_bodyless_method(req.method());
59
60 if !is_bodyless {
62 if let Some(ref allowed) = effective.allowed_content_types {
63 let content_type = req
64 .headers()
65 .get("content-type")
66 .and_then(|v| v.to_str().ok())
67 .unwrap_or("");
68 if !matches_content_type(content_type, allowed) {
69 let violation = Violation::InvalidContentType {
70 received: content_type.to_string(),
71 allowed: allowed.clone(),
72 };
73 if let Some(resp) = handle_violation(&violation, &guard.on_violation) {
74 return Ok(resp.map(Into::into));
75 }
76 }
77 }
78 }
79
80 if !effective.required_headers.is_empty() {
82 if let Some(missing) =
83 find_missing_header(req.headers(), &effective.required_headers)
84 {
85 let violation = Violation::MissingHeader { header: missing };
86 if let Some(resp) = handle_violation(&violation, &guard.on_violation) {
87 return Ok(resp.map(Into::into));
88 }
89 }
90 }
91
92 if !is_bodyless {
94 if let Some(max) = effective.max_body_size {
95 if let Some(received) = check_content_length(req.headers(), max) {
96 let violation = Violation::BodyTooLarge { max, received };
97 if let Some(resp) = handle_violation(&violation, &guard.on_violation) {
98 return Ok(resp.map(Into::into));
99 }
100 }
101 }
102 }
103
104 if let Some(timeout_duration) = effective.timeout {
117 match tokio::time::timeout(timeout_duration, inner.call(req)).await {
118 Ok(result) => result,
119 Err(_elapsed) => {
120 let violation = Violation::RequestTimeout {
121 timeout_ms: u64::try_from(timeout_duration.as_millis())
122 .unwrap_or(u64::MAX),
123 };
124 let resp = handle_timeout_violation(&violation, &guard.on_violation);
125 Ok(resp.map(Into::into))
126 }
127 }
128 } else {
129 inner.call(req).await
130 }
131 })
132 }
133}
134
135pub(crate) fn handle_violation(
138 violation: &Violation,
139 policy: &OnViolation,
140) -> Option<Response<String>> {
141 match policy {
142 OnViolation::Reject => Some(violation_response(violation)),
143 OnViolation::LogAndPass => {
144 tracing::warn!(?violation, "request guard violation (log-and-pass)");
145 None
146 }
147 OnViolation::Custom(callback) => match callback(violation) {
148 ViolationAction::Reject => Some(violation_response(violation)),
149 ViolationAction::Pass => None,
150 ViolationAction::RespondWith(resp) => Some(resp),
151 },
152 }
153}
154
155pub(crate) fn handle_timeout_violation(
157 violation: &Violation,
158 policy: &OnViolation,
159) -> Response<String> {
160 match policy {
161 OnViolation::Custom(callback) => match callback(violation) {
162 ViolationAction::RespondWith(resp) => resp,
163 _ => violation_response(violation),
164 },
165 _ => violation_response(violation),
166 }
167}