// turul_http_mcp_server/middleware/error.rs
use std::fmt;

/// Numeric error codes associated with middleware failures.
///
/// Two of the values are the codes defined by the JSON-RPC 2.0 specification.
/// The others sit in the `-32000..=-32099` block that JSON-RPC 2.0 reserves
/// for implementation-defined server errors.
pub mod error_codes {
    // --- Codes defined by the JSON-RPC 2.0 specification ---

    /// JSON-RPC "Invalid Request" code.
    pub const INVALID_REQUEST: i64 = -32600;

    /// JSON-RPC "Internal error" code.
    pub const INTERNAL_ERROR: i64 = -32603;

    // --- Implementation-defined server-error codes (-32000..=-32099) ---

    /// Code for requests that lack the required authentication.
    pub const UNAUTHENTICATED: i64 = -32001;

    /// Code for requests whose authorization was denied.
    pub const UNAUTHORIZED: i64 = -32002;

    /// Code for requests rejected by rate limiting.
    pub const RATE_LIMIT_EXCEEDED: i64 = -32003;
}

/// Error returned by middleware.
///
/// `Eq` is derived alongside `PartialEq` because every field type used here
/// (`String`, `u16`, `Option<u64>`, `Option<String>`) has total equality.
/// This lets callers use the error in `Eq`-bounded contexts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MiddlewareError {
    /// Authentication is required. The payload explains why (e.g. a missing token).
    Unauthenticated(String),

    /// Authorization was denied. The payload is a detail message.
    Unauthorized(String),

    /// The request was rejected by rate limiting.
    RateLimitExceeded {
        /// Human-readable description of the limit that was hit.
        message: String,
        /// Suggested number of seconds to wait before retrying, if known.
        retry_after: Option<u64>,
    },

    /// The request was malformed. The payload is a detail message.
    InvalidRequest(String),

    /// An internal failure inside the middleware. The payload is a detail message.
    Internal(String),

    /// Application-defined error with a free-form string code.
    Custom {
        /// Application-defined error code (e.g. `"CUSTOM_ERROR"`).
        code: String,
        /// Human-readable description.
        message: String,
    },

    /// The pieces of an HTTP authentication challenge response.
    HttpChallenge {
        /// HTTP status code to report (e.g. `401` or `403`).
        status: u16,
        /// `WWW-Authenticate` challenge value, e.g. `Bearer realm="mcp"`.
        www_authenticate: String,
        /// Optional response body. It is not part of the `Display` output.
        body: Option<String>,
    },
}

124impl fmt::Display for MiddlewareError {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 match self {
127 Self::Unauthenticated(msg) => write!(f, "Authentication required: {}", msg),
128 Self::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
129 Self::RateLimitExceeded {
130 message,
131 retry_after,
132 } => {
133 if let Some(seconds) = retry_after {
134 write!(f, "{} (retry after {} seconds)", message, seconds)
135 } else {
136 write!(f, "{}", message)
137 }
138 }
139 Self::InvalidRequest(msg) => write!(f, "Invalid request: {}", msg),
140 Self::Internal(msg) => write!(f, "Internal middleware error: {}", msg),
141 Self::Custom { code, message } => write!(f, "{}: {}", code, message),
142 Self::HttpChallenge {
143 status,
144 www_authenticate,
145 ..
146 } => write!(f, "HTTP {} WWW-Authenticate: {}", status, www_authenticate),
147 }
148 }
149}
150
151impl std::error::Error for MiddlewareError {}
152
153impl MiddlewareError {
154 pub fn unauthenticated(msg: impl Into<String>) -> Self {
156 Self::Unauthenticated(msg.into())
157 }
158
159 pub fn unauthorized(msg: impl Into<String>) -> Self {
161 Self::Unauthorized(msg.into())
162 }
163
164 pub fn rate_limit(msg: impl Into<String>, retry_after: Option<u64>) -> Self {
166 Self::RateLimitExceeded {
167 message: msg.into(),
168 retry_after,
169 }
170 }
171
172 pub fn invalid_request(msg: impl Into<String>) -> Self {
174 Self::InvalidRequest(msg.into())
175 }
176
177 pub fn internal(msg: impl Into<String>) -> Self {
179 Self::Internal(msg.into())
180 }
181
182 pub fn custom(code: impl Into<String>, message: impl Into<String>) -> Self {
184 Self::Custom {
185 code: code.into(),
186 message: message.into(),
187 }
188 }
189
190 pub fn http_challenge(status: u16, www_authenticate: impl Into<String>) -> Self {
194 Self::HttpChallenge {
195 status,
196 www_authenticate: www_authenticate.into(),
197 body: None,
198 }
199 }
200
201 pub fn http_challenge_with_body(
203 status: u16,
204 www_authenticate: impl Into<String>,
205 body: impl Into<String>,
206 ) -> Self {
207 Self::HttpChallenge {
208 status,
209 www_authenticate: www_authenticate.into(),
210 body: Some(body.into()),
211 }
212 }
213}
214
// Unit tests for `MiddlewareError`: Display text, equality, constructors,
// std::error::Error integration and the numeric error-code table.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = MiddlewareError::unauthenticated("Missing token");
        assert_eq!(err.to_string(), "Authentication required: Missing token");

        let err = MiddlewareError::unauthorized("Insufficient permissions");
        assert_eq!(err.to_string(), "Unauthorized: Insufficient permissions");

        let err = MiddlewareError::rate_limit("Too many requests", Some(60));
        assert_eq!(
            err.to_string(),
            "Too many requests (retry after 60 seconds)"
        );

        let err = MiddlewareError::rate_limit("Too many requests", None);
        assert_eq!(err.to_string(), "Too many requests");

        let err = MiddlewareError::invalid_request("Malformed params");
        assert_eq!(err.to_string(), "Invalid request: Malformed params");

        let err = MiddlewareError::internal("Database connection failed");
        assert_eq!(
            err.to_string(),
            "Internal middleware error: Database connection failed"
        );

        let err = MiddlewareError::custom("CUSTOM_ERROR", "Something went wrong");
        assert_eq!(err.to_string(), "CUSTOM_ERROR: Something went wrong");
    }

    #[test]
    fn test_error_equality() {
        let err1 = MiddlewareError::unauthenticated("test");
        let err2 = MiddlewareError::unauthenticated("test");
        assert_eq!(err1, err2);

        let err3 = MiddlewareError::rate_limit("test", Some(60));
        let err4 = MiddlewareError::rate_limit("test", Some(60));
        assert_eq!(err3, err4);
    }

    #[test]
    fn test_distinct_variants_and_fields_are_unequal() {
        // Same message, different variant.
        assert_ne!(
            MiddlewareError::unauthenticated("x"),
            MiddlewareError::unauthorized("x")
        );
        assert_ne!(
            MiddlewareError::invalid_request("x"),
            MiddlewareError::internal("x")
        );
        // Same variant and message, different retry hint.
        assert_ne!(
            MiddlewareError::rate_limit("x", Some(1)),
            MiddlewareError::rate_limit("x", None)
        );
    }

    #[test]
    fn test_http_challenge_variant_display() {
        let err = MiddlewareError::http_challenge(401, "Bearer realm=\"mcp\"");
        assert_eq!(
            err.to_string(),
            "HTTP 401 WWW-Authenticate: Bearer realm=\"mcp\""
        );

        let err = MiddlewareError::http_challenge(403, "Bearer error=\"insufficient_scope\"");
        assert_eq!(
            err.to_string(),
            "HTTP 403 WWW-Authenticate: Bearer error=\"insufficient_scope\""
        );
    }

    #[test]
    fn test_http_challenge_display_omits_body() {
        let err = MiddlewareError::http_challenge_with_body(
            401,
            "Bearer realm=\"mcp\"",
            r#"{"error":"unauthorized"}"#,
        );
        assert_eq!(
            err.to_string(),
            "HTTP 401 WWW-Authenticate: Bearer realm=\"mcp\""
        );
    }

    #[test]
    fn test_http_challenge_constructor() {
        let err = MiddlewareError::http_challenge(401, "Bearer realm=\"mcp\"");
        match &err {
            MiddlewareError::HttpChallenge {
                status,
                www_authenticate,
                body,
            } => {
                assert_eq!(*status, 401);
                assert_eq!(www_authenticate, "Bearer realm=\"mcp\"");
                assert!(body.is_none());
            }
            _ => panic!("Expected HttpChallenge variant"),
        }

        let err_with_body = MiddlewareError::http_challenge_with_body(
            401,
            "Bearer realm=\"mcp\"",
            r#"{"error":"unauthorized"}"#,
        );
        match &err_with_body {
            MiddlewareError::HttpChallenge {
                status,
                www_authenticate,
                body,
            } => {
                assert_eq!(*status, 401);
                assert_eq!(www_authenticate, "Bearer realm=\"mcp\"");
                assert_eq!(body.as_deref(), Some(r#"{"error":"unauthorized"}"#));
            }
            _ => panic!("Expected HttpChallenge variant"),
        }
    }

    #[test]
    fn test_http_challenge_roundtrip_equality() {
        let err1 = MiddlewareError::http_challenge(401, "Bearer realm=\"mcp\"");
        let err2 = MiddlewareError::http_challenge(401, "Bearer realm=\"mcp\"");
        assert_eq!(err1, err2);

        let err3 = MiddlewareError::http_challenge(401, "Bearer realm=\"mcp\"");
        let err4 = MiddlewareError::http_challenge(403, "Bearer realm=\"mcp\"");
        assert_ne!(err3, err4);
    }

    #[test]
    fn test_http_challenge_body_affects_equality() {
        let bare = MiddlewareError::http_challenge(401, "Bearer realm=\"mcp\"");
        let with_body =
            MiddlewareError::http_challenge_with_body(401, "Bearer realm=\"mcp\"", "denied");
        assert_ne!(bare, with_body);
    }

    #[test]
    fn test_usable_as_std_error() {
        let boxed: Box<dyn std::error::Error> = Box::new(MiddlewareError::internal("boom"));
        assert_eq!(boxed.to_string(), "Internal middleware error: boom");
        // No variant wraps a lower-level error, so there is no source.
        assert!(std::error::Error::source(boxed.as_ref()).is_none());
    }

    #[test]
    fn test_error_code_ranges() {
        // JSON-RPC 2.0 reserves -32000..=-32099 for implementation-defined server errors.
        for code in [
            error_codes::UNAUTHENTICATED,
            error_codes::UNAUTHORIZED,
            error_codes::RATE_LIMIT_EXCEEDED,
        ] {
            assert!((-32099..=-32000).contains(&code), "code {} out of range", code);
        }
        // Spec-defined JSON-RPC codes.
        assert_eq!(error_codes::INVALID_REQUEST, -32600);
        assert_eq!(error_codes::INTERNAL_ERROR, -32603);
    }
}