// tower_request_guard/violation.rs

use http::{Response, StatusCode};
use std::sync::Arc;

/// A request validation violation detected by the guard.
///
/// Each variant maps to an HTTP status via [`Violation::status_code`] and to a
/// stable, machine-readable identifier via [`Violation::error_key`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Violation {
    /// The request body was larger than the configured limit.
    BodyTooLarge {
        /// Maximum allowed body size (presumably bytes — confirm against the guard config).
        max: u64,
        /// Body size that was observed.
        received: u64,
    },
    /// The configured request timeout elapsed.
    RequestTimeout {
        /// The timeout that was exceeded, in milliseconds.
        timeout_ms: u64,
    },
    /// The request's content type is not in the accepted list.
    InvalidContentType {
        /// Content type the request carried.
        received: String,
        /// Content types the guard accepts.
        allowed: Vec<String>,
    },
    /// A required header was not present on the request.
    MissingHeader {
        /// Name of the missing header.
        header: String,
    },
    /// The JSON body is nested more deeply than allowed.
    JsonTooDeep {
        /// Maximum permitted nesting depth.
        max_depth: u32,
        /// Nesting depth that was detected.
        found_depth: u32,
    },
    /// The body could not be parsed as JSON.
    InvalidJson {
        /// Human-readable description of the parse failure.
        detail: String,
    },
}

30impl Violation {
31    /// Returns the appropriate HTTP status code for this violation.
32    pub fn status_code(&self) -> StatusCode {
33        match self {
34            Self::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
35            Self::RequestTimeout { .. } => StatusCode::GATEWAY_TIMEOUT,
36            Self::InvalidContentType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE,
37            Self::MissingHeader { .. } => StatusCode::BAD_REQUEST,
38            Self::JsonTooDeep { .. } => StatusCode::BAD_REQUEST,
39            Self::InvalidJson { .. } => StatusCode::BAD_REQUEST,
40        }
41    }
42
43    /// Returns the error key used in JSON responses.
44    pub fn error_key(&self) -> &'static str {
45        match self {
46            Self::BodyTooLarge { .. } => "body_too_large",
47            Self::RequestTimeout { .. } => "request_timeout",
48            Self::InvalidContentType { .. } => "invalid_content_type",
49            Self::MissingHeader { .. } => "missing_header",
50            Self::JsonTooDeep { .. } => "json_too_deep",
51            Self::InvalidJson { .. } => "invalid_json",
52        }
53    }
54}
55
56/// Action to take after evaluating a violation through OnViolation policy.
57#[derive(Default)]
58pub enum ViolationAction {
59    /// Reject the request with the default error response.
60    #[default]
61    Reject,
62    /// Let the request through (ignored for Timeout violations).
63    Pass,
64    /// Respond with a fully custom HTTP response.
65    RespondWith(Response<String>),
66}
67
68/// Policy for handling violations.
69#[derive(Clone, Default)]
70pub enum OnViolation {
71    /// Return the appropriate 4xx/5xx response immediately.
72    #[default]
73    Reject,
74    /// Log the violation via tracing::warn but forward the request.
75    /// Does NOT apply to Timeout violations (no response to forward).
76    LogAndPass,
77    /// Custom callback. Must be Fn(&Violation) -> ViolationAction + Send + Sync + 'static.
78    Custom(Arc<dyn Fn(&Violation) -> ViolationAction + Send + Sync>),
79}
80
81impl OnViolation {
82    /// Create a custom violation handler from a closure.
83    pub fn custom<F>(f: F) -> Self
84    where
85        F: Fn(&Violation) -> ViolationAction + Send + Sync + 'static,
86    {
87        Self::Custom(Arc::new(f))
88    }
89}
90
91impl std::fmt::Debug for OnViolation {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        match self {
94            Self::Reject => write!(f, "OnViolation::Reject"),
95            Self::LogAndPass => write!(f, "OnViolation::LogAndPass"),
96            Self::Custom(_) => write!(f, "OnViolation::Custom(...)"),
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use std::sync::Arc;

    /// One instance of every `Violation` variant, in declaration order.
    fn all_violations() -> Vec<Violation> {
        vec![
            Violation::BodyTooLarge {
                max: 100,
                received: 200,
            },
            Violation::RequestTimeout { timeout_ms: 5000 },
            Violation::InvalidContentType {
                received: "text/xml".into(),
                allowed: vec!["application/json".into()],
            },
            Violation::MissingHeader {
                header: "Authorization".into(),
            },
            Violation::JsonTooDeep {
                max_depth: 32,
                found_depth: 128,
            },
            Violation::InvalidJson {
                detail: "unexpected EOF".into(),
            },
        ]
    }

    #[test]
    fn violation_status_codes() {
        assert_eq!(
            Violation::BodyTooLarge {
                max: 100,
                received: 200
            }
            .status_code(),
            StatusCode::PAYLOAD_TOO_LARGE
        );
        assert_eq!(
            Violation::RequestTimeout { timeout_ms: 5000 }.status_code(),
            StatusCode::GATEWAY_TIMEOUT
        );
        assert_eq!(
            Violation::InvalidContentType {
                received: "text/xml".into(),
                allowed: vec!["application/json".into()],
            }
            .status_code(),
            StatusCode::UNSUPPORTED_MEDIA_TYPE
        );
        assert_eq!(
            Violation::MissingHeader {
                header: "Authorization".into()
            }
            .status_code(),
            StatusCode::BAD_REQUEST
        );
        assert_eq!(
            Violation::JsonTooDeep {
                max_depth: 32,
                found_depth: 128
            }
            .status_code(),
            StatusCode::BAD_REQUEST
        );
        assert_eq!(
            Violation::InvalidJson {
                detail: "unexpected EOF".into()
            }
            .status_code(),
            StatusCode::BAD_REQUEST
        );
    }

    #[test]
    fn violation_error_keys() {
        // Exact keys in declaration order; this also guarantees they are unique.
        let keys: Vec<&str> = all_violations().iter().map(Violation::error_key).collect();
        assert_eq!(
            keys,
            [
                "body_too_large",
                "request_timeout",
                "invalid_content_type",
                "missing_header",
                "json_too_deep",
                "invalid_json",
            ]
        );
    }

    #[test]
    fn on_violation_default_is_reject() {
        assert!(matches!(OnViolation::default(), OnViolation::Reject));
    }

    #[test]
    fn violation_action_default_is_reject() {
        assert!(matches!(
            ViolationAction::default(),
            ViolationAction::Reject
        ));
    }

    #[test]
    fn custom_policy_invokes_callback() {
        let policy = OnViolation::custom(|v| match v {
            Violation::MissingHeader { .. } => ViolationAction::Pass,
            _ => ViolationAction::Reject,
        });
        match &policy {
            OnViolation::Custom(handler) => {
                let missing = Violation::MissingHeader {
                    header: "X-Api-Key".into(),
                };
                let bad_json = Violation::InvalidJson {
                    detail: "bad".into(),
                };
                assert!(matches!(handler(&missing), ViolationAction::Pass));
                assert!(matches!(handler(&bad_json), ViolationAction::Reject));
            }
            other => panic!("expected OnViolation::Custom, got {:?}", other),
        }
    }

    #[test]
    fn cloned_custom_policy_shares_callback() {
        let policy = OnViolation::custom(|_| ViolationAction::Reject);
        let copy = policy.clone();
        match (&policy, &copy) {
            (OnViolation::Custom(a), OnViolation::Custom(b)) => assert!(Arc::ptr_eq(a, b)),
            other => panic!("expected two Custom policies, got {:?}", other),
        }
    }

    #[test]
    fn on_violation_debug_output() {
        assert_eq!(format!("{:?}", OnViolation::Reject), "OnViolation::Reject");
        assert_eq!(
            format!("{:?}", OnViolation::LogAndPass),
            "OnViolation::LogAndPass"
        );
        assert_eq!(
            format!("{:?}", OnViolation::custom(|_| ViolationAction::Pass)),
            "OnViolation::Custom(...)"
        );
    }
}