tower_request_guard/
violation.rs1use http::{Response, StatusCode};
2use std::sync::Arc;
3
/// A rule that a request failed to satisfy.
///
/// Every variant carries the data describing that particular failure.
/// `PartialEq`/`Eq` are derived so violations can be compared by value
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Violation {
    /// The body exceeded the permitted size.
    BodyTooLarge {
        /// Largest permitted size (unit not visible in this file; presumably bytes).
        max: u64,
        /// Size that was actually observed, in the same unit as `max`.
        received: u64,
    },
    /// The request ran into the configured time limit.
    RequestTimeout {
        /// The time limit; milliseconds, going by the field name.
        timeout_ms: u64,
    },
    /// The content type was not one of the accepted ones.
    InvalidContentType {
        /// Content type that was received.
        received: String,
        /// Content types that would have been accepted.
        allowed: Vec<String>,
    },
    /// A required header was absent.
    MissingHeader {
        /// Name of the absent header.
        header: String,
    },
    /// The JSON document was nested more deeply than permitted.
    JsonTooDeep {
        /// Deepest nesting permitted.
        max_depth: u32,
        /// Nesting depth that was found.
        found_depth: u32,
    },
    /// The JSON document was rejected as invalid.
    InvalidJson {
        /// Free-form description of the problem.
        detail: String,
    },
}
29
30impl Violation {
31 pub fn status_code(&self) -> StatusCode {
33 match self {
34 Self::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
35 Self::RequestTimeout { .. } => StatusCode::GATEWAY_TIMEOUT,
36 Self::InvalidContentType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE,
37 Self::MissingHeader { .. } => StatusCode::BAD_REQUEST,
38 Self::JsonTooDeep { .. } => StatusCode::BAD_REQUEST,
39 Self::InvalidJson { .. } => StatusCode::BAD_REQUEST,
40 }
41 }
42
43 pub fn error_key(&self) -> &'static str {
45 match self {
46 Self::BodyTooLarge { .. } => "body_too_large",
47 Self::RequestTimeout { .. } => "request_timeout",
48 Self::InvalidContentType { .. } => "invalid_content_type",
49 Self::MissingHeader { .. } => "missing_header",
50 Self::JsonTooDeep { .. } => "json_too_deep",
51 Self::InvalidJson { .. } => "invalid_json",
52 }
53 }
54}
55
56#[derive(Default)]
58pub enum ViolationAction {
59 #[default]
61 Reject,
62 Pass,
64 RespondWith(Response<String>),
66}
67
68#[derive(Clone, Default)]
70pub enum OnViolation {
71 #[default]
73 Reject,
74 LogAndPass,
77 Custom(Arc<dyn Fn(&Violation) -> ViolationAction + Send + Sync>),
79}
80
81impl OnViolation {
82 pub fn custom<F>(f: F) -> Self
84 where
85 F: Fn(&Violation) -> ViolationAction + Send + Sync + 'static,
86 {
87 Self::Custom(Arc::new(f))
88 }
89}
90
91impl std::fmt::Debug for OnViolation {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match self {
94 Self::Reject => write!(f, "OnViolation::Reject"),
95 Self::LogAndPass => write!(f, "OnViolation::LogAndPass"),
96 Self::Custom(_) => write!(f, "OnViolation::Custom(...)"),
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;

    /// Every variant must map to its expected HTTP status code.
    #[test]
    fn violation_status_codes() {
        let cases = vec![
            (
                Violation::BodyTooLarge {
                    max: 100,
                    received: 200,
                },
                StatusCode::PAYLOAD_TOO_LARGE,
            ),
            (
                Violation::RequestTimeout { timeout_ms: 5000 },
                StatusCode::GATEWAY_TIMEOUT,
            ),
            (
                Violation::InvalidContentType {
                    received: "text/xml".into(),
                    allowed: vec!["application/json".into()],
                },
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
            ),
            (
                Violation::MissingHeader {
                    header: "Authorization".into(),
                },
                StatusCode::BAD_REQUEST,
            ),
            (
                Violation::JsonTooDeep {
                    max_depth: 32,
                    found_depth: 128,
                },
                StatusCode::BAD_REQUEST,
            ),
            (
                Violation::InvalidJson {
                    detail: "unexpected EOF".into(),
                },
                StatusCode::BAD_REQUEST,
            ),
        ];

        for (violation, expected) in &cases {
            assert_eq!(
                violation.status_code(),
                *expected,
                "unexpected status for {:?}",
                violation
            );
        }
    }

    /// Every variant must expose its fixed identifier.
    #[test]
    fn violation_error_keys() {
        let cases = vec![
            (
                Violation::BodyTooLarge {
                    max: 1,
                    received: 2,
                },
                "body_too_large",
            ),
            (Violation::RequestTimeout { timeout_ms: 1 }, "request_timeout"),
            (
                Violation::InvalidContentType {
                    received: "text/xml".into(),
                    allowed: Vec::new(),
                },
                "invalid_content_type",
            ),
            (
                Violation::MissingHeader {
                    header: "Authorization".into(),
                },
                "missing_header",
            ),
            (
                Violation::JsonTooDeep {
                    max_depth: 1,
                    found_depth: 2,
                },
                "json_too_deep",
            ),
            (
                Violation::InvalidJson {
                    detail: "unexpected EOF".into(),
                },
                "invalid_json",
            ),
        ];

        for (violation, expected) in &cases {
            assert_eq!(
                violation.error_key(),
                *expected,
                "unexpected key for {:?}",
                violation
            );
        }
    }

    #[test]
    fn on_violation_default_is_reject() {
        assert!(matches!(OnViolation::default(), OnViolation::Reject));
    }

    #[test]
    fn violation_action_default_is_reject() {
        assert!(matches!(
            ViolationAction::default(),
            ViolationAction::Reject
        ));
    }

    /// The `Debug` output names the variant and hides the callback.
    #[test]
    fn on_violation_debug_output() {
        assert_eq!(format!("{:?}", OnViolation::Reject), "OnViolation::Reject");
        assert_eq!(
            format!("{:?}", OnViolation::LogAndPass),
            "OnViolation::LogAndPass"
        );
        assert_eq!(
            format!("{:?}", OnViolation::custom(|_| ViolationAction::Pass)),
            "OnViolation::Custom(...)"
        );
    }

    /// `custom` stores the callback, and the callback sees the violation it is given.
    #[test]
    fn custom_handler_receives_violation() {
        let policy = OnViolation::custom(|violation| match violation {
            Violation::MissingHeader { .. } => ViolationAction::Pass,
            _ => ViolationAction::Reject,
        });

        let handler = match policy {
            OnViolation::Custom(handler) => handler,
            other => panic!("expected a custom policy, got {:?}", other),
        };

        let missing = Violation::MissingHeader {
            header: "Authorization".into(),
        };
        let timeout = Violation::RequestTimeout { timeout_ms: 5000 };

        assert!(matches!(handler(&missing), ViolationAction::Pass));
        assert!(matches!(handler(&timeout), ViolationAction::Reject));
    }
}