1use std::collections::HashMap;
6use std::fmt;
7use std::time::Duration;
8
9#[derive(Debug, Clone, PartialEq)]
11pub struct Uri {
12 pub scheme: String,
14 pub host: String,
16 pub port: Option<u16>,
18 pub path: String,
20 pub query: Option<String>,
22 pub fragment: Option<String>,
24 pub userinfo: Option<String>,
26}
27
28impl Uri {
29 pub fn parse(s: &str) -> Result<Self, ProtocolError> {
31 #[cfg(feature = "url")]
32 {
33 use url::Url;
34 let url = Url::parse(s).map_err(|e| ProtocolError::InvalidUri(e.to_string()))?;
35
36 Ok(Uri {
37 scheme: url.scheme().to_string(),
38 host: url.host_str().unwrap_or("").to_string(),
39 port: url.port(),
40 path: url.path().to_string(),
41 query: url.query().map(|s| s.to_string()),
42 fragment: url.fragment().map(|s| s.to_string()),
43 userinfo: if url.username().is_empty() {
44 None
45 } else {
46 Some(format!(
47 "{}:{}",
48 url.username(),
49 url.password().unwrap_or("")
50 ))
51 },
52 })
53 }
54
55 #[cfg(not(feature = "url"))]
56 {
57 let mut uri = Uri {
59 scheme: String::new(),
60 host: String::new(),
61 port: None,
62 path: String::from("/"),
63 query: None,
64 fragment: None,
65 userinfo: None,
66 };
67
68 let s = s.trim();
69
70 if let Some(pos) = s.find("://") {
72 uri.scheme = s[..pos].to_string();
73 let rest = &s[pos + 3..];
74
75 let (authority, path_and_rest) = if let Some(pos) = rest.find('/') {
77 (&rest[..pos], &rest[pos..])
78 } else {
79 (rest, "/")
80 };
81
82 let host_port = if let Some(pos) = authority.find('@') {
84 uri.userinfo = Some(authority[..pos].to_string());
85 &authority[pos + 1..]
86 } else {
87 authority
88 };
89
90 if let Some(pos) = host_port.rfind(':') {
92 uri.host = host_port[..pos].to_string();
93 if let Ok(port) = host_port[pos + 1..].parse() {
94 uri.port = Some(port);
95 }
96 } else {
97 uri.host = host_port.to_string();
98 }
99
100 let (path_query, fragment) = if let Some(pos) = path_and_rest.find('#') {
102 uri.fragment = Some(path_and_rest[pos + 1..].to_string());
103 (&path_and_rest[..pos], Some(&path_and_rest[pos + 1..]))
104 } else {
105 (path_and_rest, None)
106 };
107
108 if let Some(pos) = path_query.find('?') {
109 uri.path = path_query[..pos].to_string();
110 uri.query = Some(path_query[pos + 1..].to_string());
111 } else {
112 uri.path = path_query.to_string();
113 }
114 } else {
115 return Err(ProtocolError::InvalidUri("Missing scheme".to_string()));
116 }
117
118 Ok(uri)
119 }
120 }
121
122 pub fn to_string(&self) -> String {
124 let mut s = format!("{}://", self.scheme);
125
126 if let Some(ref userinfo) = self.userinfo {
127 s.push_str(userinfo);
128 s.push('@');
129 }
130
131 s.push_str(&self.host);
132
133 if let Some(port) = self.port {
134 s.push(':');
135 s.push_str(&port.to_string());
136 }
137
138 s.push_str(&self.path);
139
140 if let Some(ref query) = self.query {
141 s.push('?');
142 s.push_str(query);
143 }
144
145 if let Some(ref fragment) = self.fragment {
146 s.push('#');
147 s.push_str(fragment);
148 }
149
150 s
151 }
152
153 pub fn authority(&self) -> String {
155 if let Some(port) = self.port {
156 format!("{}:{}", self.host, port)
157 } else {
158 self.host.clone()
159 }
160 }
161}
162
163impl fmt::Display for Uri {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 write!(f, "{}", self.to_string())
166 }
167}
168
/// A case-insensitive, multi-valued collection of HTTP header fields.
///
/// Every name is lowercased before it is stored or looked up, so
/// `"Content-Type"` and `"content-type"` address the same entry. Each name maps
/// to its values in insertion order.
#[derive(Debug, Clone, Default)]
pub struct Headers {
    inner: HashMap<String, Vec<String>>,
}

impl Headers {
    /// Creates an empty header collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Canonical (lowercased) form of a header name used as the map key.
    fn normalize(name: &str) -> String {
        name.to_lowercase()
    }

    /// Appends `value` to the values already stored under `key`.
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
        let name: String = key.into();
        let name = Self::normalize(&name);
        let value: String = value.into();
        self.inner.entry(name).or_default().push(value);
        self
    }

    /// Replaces every value stored under `key` with the single `value`.
    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
        let name: String = key.into();
        let name = Self::normalize(&name);
        let values = vec![value.into()];
        self.inner.insert(name, values);
        self
    }

    /// Returns the first value stored under `key`, if any.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.get_all(key)
            .and_then(|values| values.first())
            .map(String::as_str)
    }

    /// Returns every value stored under `key`, in insertion order.
    pub fn get_all(&self, key: &str) -> Option<&[String]> {
        self.inner.get(&Self::normalize(key)).map(Vec::as_slice)
    }

    /// Removes `key` and returns the values it held.
    pub fn remove(&mut self, key: &str) -> Option<Vec<String>> {
        let name = Self::normalize(key);
        self.inner.remove(&name)
    }

    /// Reports whether any value is stored under `key`.
    pub fn contains(&self, key: &str) -> bool {
        let name = Self::normalize(key);
        self.inner.contains_key(&name)
    }

    /// First `Content-Type` value, if present.
    pub fn content_type(&self) -> Option<&str> {
        self.get("content-type")
    }

    /// First `Content-Length` value parsed as `u64`; `None` if absent or not numeric.
    pub fn content_length(&self) -> Option<u64> {
        let raw = self.get("content-length")?;
        raw.parse().ok()
    }

    /// Iterates over `(lowercased name, values)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &Vec<String>)> {
        self.inner.iter()
    }

    /// Number of distinct header names (not the total number of values).
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when no header names are stored.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}

impl IntoIterator for Headers {
    type Item = (String, Vec<String>);
    type IntoIter = std::collections::hash_map::IntoIter<String, Vec<String>>;

    /// Consumes the collection, yielding `(lowercased name, values)` pairs.
    fn into_iter(self) -> Self::IntoIter {
        self.inner.into_iter()
    }
}

impl<K: Into<String>, V: Into<String>> std::iter::FromIterator<(K, V)> for Headers {
    /// Builds a collection by [`Headers::insert`]-ing each pair, so repeated
    /// names accumulate values.
    fn from_iter<I: IntoIterator<Item = (K, V)>>(iter: I) -> Self {
        let mut headers = Self::new();
        iter.into_iter().for_each(|(name, value)| {
            headers.insert(name, value);
        });
        headers
    }
}
265
/// An HTTP response status code.
///
/// Any `u16` can be wrapped; the associated constants name the codes for which
/// [`StatusCode::reason_phrase`] has a specific phrase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StatusCode(u16);

impl StatusCode {
    // 1xx: informational
    pub const CONTINUE: StatusCode = StatusCode(100);
    pub const SWITCHING_PROTOCOLS: StatusCode = StatusCode(101);

    // 2xx: success
    pub const OK: StatusCode = StatusCode(200);
    pub const CREATED: StatusCode = StatusCode(201);
    pub const ACCEPTED: StatusCode = StatusCode(202);
    pub const NO_CONTENT: StatusCode = StatusCode(204);

    // 3xx: redirection
    pub const MOVED_PERMANENTLY: StatusCode = StatusCode(301);
    pub const FOUND: StatusCode = StatusCode(302);
    pub const SEE_OTHER: StatusCode = StatusCode(303);
    pub const NOT_MODIFIED: StatusCode = StatusCode(304);
    pub const TEMPORARY_REDIRECT: StatusCode = StatusCode(307);
    pub const PERMANENT_REDIRECT: StatusCode = StatusCode(308);

    // 4xx: client error
    pub const BAD_REQUEST: StatusCode = StatusCode(400);
    pub const UNAUTHORIZED: StatusCode = StatusCode(401);
    pub const FORBIDDEN: StatusCode = StatusCode(403);
    pub const NOT_FOUND: StatusCode = StatusCode(404);
    pub const METHOD_NOT_ALLOWED: StatusCode = StatusCode(405);
    pub const CONFLICT: StatusCode = StatusCode(409);
    pub const GONE: StatusCode = StatusCode(410);
    pub const UNPROCESSABLE_ENTITY: StatusCode = StatusCode(422);
    pub const TOO_MANY_REQUESTS: StatusCode = StatusCode(429);

    // 5xx: server error
    pub const INTERNAL_SERVER_ERROR: StatusCode = StatusCode(500);
    pub const NOT_IMPLEMENTED: StatusCode = StatusCode(501);
    pub const BAD_GATEWAY: StatusCode = StatusCode(502);
    pub const SERVICE_UNAVAILABLE: StatusCode = StatusCode(503);
    pub const GATEWAY_TIMEOUT: StatusCode = StatusCode(504);

    /// Wraps a raw numeric code; no range validation is performed.
    pub fn from_u16(code: u16) -> Self {
        Self(code)
    }

    /// Returns the raw numeric code.
    pub fn as_u16(&self) -> u16 {
        self.0
    }

    /// Hundreds digit of the code (e.g. `4` for `404`), shared by the class checks.
    fn class(&self) -> u16 {
        self.0 / 100
    }

    /// `true` for 1xx codes.
    pub fn is_informational(&self) -> bool {
        self.class() == 1
    }

    /// `true` for 2xx codes.
    pub fn is_success(&self) -> bool {
        self.class() == 2
    }

    /// `true` for 3xx codes.
    pub fn is_redirection(&self) -> bool {
        self.class() == 3
    }

    /// `true` for 4xx codes.
    pub fn is_client_error(&self) -> bool {
        self.class() == 4
    }

    /// `true` for 5xx codes.
    pub fn is_server_error(&self) -> bool {
        self.class() == 5
    }

    /// Canonical reason phrase for the named codes, `"Unknown"` for anything else.
    pub fn reason_phrase(&self) -> &'static str {
        match *self {
            StatusCode::CONTINUE => "Continue",
            StatusCode::SWITCHING_PROTOCOLS => "Switching Protocols",
            StatusCode::OK => "OK",
            StatusCode::CREATED => "Created",
            StatusCode::ACCEPTED => "Accepted",
            StatusCode::NO_CONTENT => "No Content",
            StatusCode::MOVED_PERMANENTLY => "Moved Permanently",
            StatusCode::FOUND => "Found",
            StatusCode::SEE_OTHER => "See Other",
            StatusCode::NOT_MODIFIED => "Not Modified",
            StatusCode::TEMPORARY_REDIRECT => "Temporary Redirect",
            StatusCode::PERMANENT_REDIRECT => "Permanent Redirect",
            StatusCode::BAD_REQUEST => "Bad Request",
            StatusCode::UNAUTHORIZED => "Unauthorized",
            StatusCode::FORBIDDEN => "Forbidden",
            StatusCode::NOT_FOUND => "Not Found",
            StatusCode::METHOD_NOT_ALLOWED => "Method Not Allowed",
            StatusCode::CONFLICT => "Conflict",
            StatusCode::GONE => "Gone",
            StatusCode::UNPROCESSABLE_ENTITY => "Unprocessable Entity",
            StatusCode::TOO_MANY_REQUESTS => "Too Many Requests",
            StatusCode::INTERNAL_SERVER_ERROR => "Internal Server Error",
            StatusCode::NOT_IMPLEMENTED => "Not Implemented",
            StatusCode::BAD_GATEWAY => "Bad Gateway",
            StatusCode::SERVICE_UNAVAILABLE => "Service Unavailable",
            StatusCode::GATEWAY_TIMEOUT => "Gateway Timeout",
            _ => "Unknown",
        }
    }
}

impl fmt::Display for StatusCode {
    /// Formats as `"<code> <reason phrase>"`, e.g. `404 Not Found`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} {}", self.as_u16(), self.reason_phrase())
    }
}
381
/// Optional per-phase time limits; a field left as `None` is simply unset.
#[derive(Debug, Clone, Default)]
pub struct Timeout {
    /// Limit for the connect phase.
    pub connect: Option<Duration>,
    /// Limit for read operations.
    pub read: Option<Duration>,
    /// Limit for write operations.
    pub write: Option<Duration>,
    /// Limit for the operation as a whole.
    pub total: Option<Duration>,
}

impl Timeout {
    /// Creates a `Timeout` with every limit unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a `Timeout` that applies `duration` to every phase.
    pub fn all(duration: Duration) -> Self {
        let limit = Some(duration);
        Self {
            connect: limit,
            read: limit,
            write: limit,
            total: limit,
        }
    }

    /// Builder: sets the connect limit.
    pub fn connect_timeout(self, duration: Duration) -> Self {
        Self {
            connect: Some(duration),
            ..self
        }
    }

    /// Builder: sets the read limit.
    pub fn read_timeout(self, duration: Duration) -> Self {
        Self {
            read: Some(duration),
            ..self
        }
    }

    /// Builder: sets the write limit.
    pub fn write_timeout(self, duration: Duration) -> Self {
        Self {
            write: Some(duration),
            ..self
        }
    }

    /// Builder: sets the overall limit.
    pub fn total_timeout(self, duration: Duration) -> Self {
        Self {
            total: Some(duration),
            ..self
        }
    }
}
446
/// Policy for how long to wait before each retry attempt.
///
/// Attempts are numbered from zero: attempt `0` waits the fixed delay or
/// `initial`. Linear and exponential delays saturate instead of panicking on
/// overflow, and are then limited by `max` when one is set.
#[derive(Debug, Clone)]
pub enum BackoffStrategy {
    /// Wait the same duration before every attempt.
    Fixed(Duration),
    /// Wait `initial + increment * attempt`, capped at `max` if set.
    Linear {
        initial: Duration,
        increment: Duration,
        max: Option<Duration>,
    },
    /// Wait `initial * factor^attempt`, capped at `max` if set.
    Exponential {
        initial: Duration,
        factor: f64,
        max: Option<Duration>,
    },
}

impl BackoffStrategy {
    /// Returns the delay to wait before `attempt` (zero-based).
    ///
    /// Never panics. Overflowing delays saturate to `Duration::MAX` before the
    /// optional `max` cap is applied. An exponential delay that is negative
    /// (negative `factor`) or NaN (e.g. zero `initial` times an infinite
    /// multiplier) becomes zero.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        match *self {
            BackoffStrategy::Fixed(delay) => delay,
            BackoffStrategy::Linear {
                initial,
                increment,
                max,
            } => {
                let delay = initial.saturating_add(increment.saturating_mul(attempt));
                Self::cap(delay, max)
            }
            BackoffStrategy::Exponential {
                initial,
                factor,
                max,
            } => {
                // Clamp instead of `attempt as i32`, which wraps to a negative
                // exponent for attempts above i32::MAX.
                let exponent = attempt.min(i32::MAX as u32) as i32;
                // Same arithmetic as `Duration::mul_f64`, but range-checked:
                // `mul_f64` panics on infinite/NaN/negative results, which
                // previously happened even when `max` would have capped them.
                let secs = initial.as_secs_f64() * factor.powi(exponent);
                Self::cap(Self::duration_from_secs_saturating(secs), max)
            }
        }
    }

    /// Applies the optional upper bound.
    fn cap(delay: Duration, max: Option<Duration>) -> Duration {
        max.map_or(delay, |limit| delay.min(limit))
    }

    /// Converts fractional seconds to a `Duration`, saturating instead of panicking.
    fn duration_from_secs_saturating(secs: f64) -> Duration {
        if secs.is_nan() || secs <= 0.0 {
            Duration::ZERO
        } else if secs >= u64::MAX as f64 {
            // Includes +infinity; `Duration::from_secs_f64` would panic here.
            Duration::MAX
        } else {
            Duration::from_secs_f64(secs)
        }
    }
}
491
492#[derive(Debug, Clone)]
494pub struct RetryConfig {
495 pub max_attempts: u32,
497 pub backoff: BackoffStrategy,
499 pub retry_on_status: Vec<StatusCode>,
501 pub retry_on_connection_error: bool,
503 pub retry_on_timeout: bool,
505}
506
507impl RetryConfig {
508 pub fn new(max_attempts: u32) -> Self {
510 RetryConfig {
511 max_attempts,
512 backoff: BackoffStrategy::Exponential {
513 initial: Duration::from_millis(100),
514 factor: 2.0,
515 max: Some(Duration::from_secs(30)),
516 },
517 retry_on_status: vec![
518 StatusCode::SERVICE_UNAVAILABLE,
519 StatusCode::GATEWAY_TIMEOUT,
520 StatusCode::TOO_MANY_REQUESTS,
521 ],
522 retry_on_connection_error: true,
523 retry_on_timeout: true,
524 }
525 }
526
527 pub fn backoff(mut self, strategy: BackoffStrategy) -> Self {
529 self.backoff = strategy;
530 self
531 }
532
533 pub fn retry_on(mut self, codes: Vec<StatusCode>) -> Self {
535 self.retry_on_status = codes;
536 self
537 }
538
539 pub fn should_retry_status(&self, status: StatusCode) -> bool {
541 self.retry_on_status.contains(&status)
542 }
543}
544
545impl Default for RetryConfig {
546 fn default() -> Self {
547 RetryConfig::new(3)
548 }
549}
550
551#[derive(Debug, Clone)]
553pub enum ProtocolError {
554 InvalidUri(String),
556 ConnectionFailed(String),
558 ConnectionTimeout,
560 ReadTimeout,
562 WriteTimeout,
564 RequestTimeout,
566 TlsError(String),
568 Protocol(String),
570 Serialization(String),
572 Deserialization(String),
574 Authentication(String),
576 Authorization(String),
578 RateLimited { retry_after: Option<Duration> },
580 NotFound(String),
582 ServerError(StatusCode, String),
584 ClientError(StatusCode, String),
586 Io(String),
588 ChannelClosed,
590 Cancelled,
592 Other(String),
594}
595
596impl fmt::Display for ProtocolError {
597 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
598 match self {
599 ProtocolError::InvalidUri(msg) => write!(f, "Invalid URI: {}", msg),
600 ProtocolError::ConnectionFailed(msg) => write!(f, "Connection failed: {}", msg),
601 ProtocolError::ConnectionTimeout => write!(f, "Connection timeout"),
602 ProtocolError::ReadTimeout => write!(f, "Read timeout"),
603 ProtocolError::WriteTimeout => write!(f, "Write timeout"),
604 ProtocolError::RequestTimeout => write!(f, "Request timeout"),
605 ProtocolError::TlsError(msg) => write!(f, "TLS error: {}", msg),
606 ProtocolError::Protocol(msg) => write!(f, "Protocol error: {}", msg),
607 ProtocolError::Serialization(msg) => write!(f, "Serialization error: {}", msg),
608 ProtocolError::Deserialization(msg) => write!(f, "Deserialization error: {}", msg),
609 ProtocolError::Authentication(msg) => write!(f, "Authentication error: {}", msg),
610 ProtocolError::Authorization(msg) => write!(f, "Authorization error: {}", msg),
611 ProtocolError::RateLimited { retry_after } => {
612 if let Some(d) = retry_after {
613 write!(f, "Rate limited, retry after {:?}", d)
614 } else {
615 write!(f, "Rate limited")
616 }
617 }
618 ProtocolError::NotFound(msg) => write!(f, "Not found: {}", msg),
619 ProtocolError::ServerError(code, msg) => write!(f, "Server error ({}): {}", code, msg),
620 ProtocolError::ClientError(code, msg) => write!(f, "Client error ({}): {}", code, msg),
621 ProtocolError::Io(msg) => write!(f, "IO error: {}", msg),
622 ProtocolError::ChannelClosed => write!(f, "Channel closed"),
623 ProtocolError::Cancelled => write!(f, "Operation cancelled"),
624 ProtocolError::Other(msg) => write!(f, "{}", msg),
625 }
626 }
627}
628
629impl std::error::Error for ProtocolError {}
630
631pub type ProtocolResult<T> = Result<T, ProtocolError>;
633
634#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
636pub enum Method {
637 GET,
638 POST,
639 PUT,
640 DELETE,
641 PATCH,
642 HEAD,
643 OPTIONS,
644 CONNECT,
645 TRACE,
646}
647
648impl Method {
649 pub fn as_str(&self) -> &'static str {
651 match self {
652 Method::GET => "GET",
653 Method::POST => "POST",
654 Method::PUT => "PUT",
655 Method::DELETE => "DELETE",
656 Method::PATCH => "PATCH",
657 Method::HEAD => "HEAD",
658 Method::OPTIONS => "OPTIONS",
659 Method::CONNECT => "CONNECT",
660 Method::TRACE => "TRACE",
661 }
662 }
663
664 pub fn is_idempotent(&self) -> bool {
666 matches!(
667 self,
668 Method::GET
669 | Method::HEAD
670 | Method::PUT
671 | Method::DELETE
672 | Method::OPTIONS
673 | Method::TRACE
674 )
675 }
676
677 pub fn is_safe(&self) -> bool {
679 matches!(
680 self,
681 Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE
682 )
683 }
684}
685
686impl fmt::Display for Method {
687 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
688 write!(f, "{}", self.as_str())
689 }
690}
691
692impl TryFrom<&str> for Method {
693 type Error = ProtocolError;
694
695 fn try_from(s: &str) -> Result<Self, Self::Error> {
696 match s.to_uppercase().as_str() {
697 "GET" => Ok(Method::GET),
698 "POST" => Ok(Method::POST),
699 "PUT" => Ok(Method::PUT),
700 "DELETE" => Ok(Method::DELETE),
701 "PATCH" => Ok(Method::PATCH),
702 "HEAD" => Ok(Method::HEAD),
703 "OPTIONS" => Ok(Method::OPTIONS),
704 "CONNECT" => Ok(Method::CONNECT),
705 "TRACE" => Ok(Method::TRACE),
706 _ => Err(ProtocolError::Protocol(format!(
707 "Unknown HTTP method: {}",
708 s
709 ))),
710 }
711 }
712}
713
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uri_parsing() {
        let input = "https://user:pass@example.com:8080/path?query=value#fragment";
        let uri = Uri::parse(input).expect("well-formed URI should parse");

        assert_eq!(uri.scheme, "https");
        assert_eq!(uri.userinfo.as_deref(), Some("user:pass"));
        assert_eq!(uri.host, "example.com");
        assert_eq!(uri.port, Some(8080));
        assert_eq!(uri.path, "/path");
        assert_eq!(uri.query.as_deref(), Some("query=value"));
        assert_eq!(uri.fragment.as_deref(), Some("fragment"));
    }

    #[test]
    fn test_headers() {
        let mut headers = Headers::new();
        headers
            .insert("Content-Type", "application/json")
            .insert("X-Custom", "value1")
            .insert("X-Custom", "value2");

        // Lookups ignore case, and repeated inserts accumulate values.
        assert_eq!(headers.get("content-type"), Some("application/json"));
        assert_eq!(headers.get_all("x-custom").map(|values| values.len()), Some(2));
    }

    #[test]
    fn test_status_code() {
        assert!(StatusCode::OK.is_success());
        assert!(StatusCode::MOVED_PERMANENTLY.is_redirection());
        assert!(StatusCode::NOT_FOUND.is_client_error());
        assert!(StatusCode::INTERNAL_SERVER_ERROR.is_server_error());
    }

    #[test]
    fn test_backoff_strategy() {
        let exp = BackoffStrategy::Exponential {
            initial: Duration::from_millis(100),
            factor: 2.0,
            max: Some(Duration::from_secs(10)),
        };

        // Doubling from 100 ms until the 10 s cap takes over.
        let expectations = [
            (0, Duration::from_millis(100)),
            (1, Duration::from_millis(200)),
            (2, Duration::from_millis(400)),
            (10, Duration::from_secs(10)),
        ];
        for &(attempt, expected) in &expectations {
            assert_eq!(exp.delay_for_attempt(attempt), expected);
        }
    }
}