Skip to main content

sigil_parser/protocol/
common.rs

1//! Common Protocol Types
2//!
3//! Shared abstractions used across all protocol implementations.
4
5use std::collections::HashMap;
6use std::fmt;
7use std::time::Duration;
8
/// A parsed URI with all components.
///
/// Every field is a plain value, so the type derives `Eq` and `Hash` as well as `PartialEq`,
/// allowing URIs to be used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Uri {
    /// The scheme (e.g., "https", "grpc", "ws").
    pub scheme: String,
    /// The host (e.g., "example.com").
    pub host: String,
    /// The port number, if present.
    pub port: Option<u16>,
    /// The path (e.g., "/api/v1/users").
    pub path: String,
    /// The query string, without the leading `?`.
    pub query: Option<String>,
    /// The fragment, without the leading `#`.
    pub fragment: Option<String>,
    /// User info for authentication (e.g. "user:pass"), without the trailing `@`.
    pub userinfo: Option<String>,
}
27
28impl Uri {
29    /// Parse a URI string into components
30    pub fn parse(s: &str) -> Result<Self, ProtocolError> {
31        #[cfg(feature = "url")]
32        {
33            use url::Url;
34            let url = Url::parse(s).map_err(|e| ProtocolError::InvalidUri(e.to_string()))?;
35
36            Ok(Uri {
37                scheme: url.scheme().to_string(),
38                host: url.host_str().unwrap_or("").to_string(),
39                port: url.port(),
40                path: url.path().to_string(),
41                query: url.query().map(|s| s.to_string()),
42                fragment: url.fragment().map(|s| s.to_string()),
43                userinfo: if url.username().is_empty() {
44                    None
45                } else {
46                    Some(format!(
47                        "{}:{}",
48                        url.username(),
49                        url.password().unwrap_or("")
50                    ))
51                },
52            })
53        }
54
55        #[cfg(not(feature = "url"))]
56        {
57            // Basic fallback parser
58            let mut uri = Uri {
59                scheme: String::new(),
60                host: String::new(),
61                port: None,
62                path: String::from("/"),
63                query: None,
64                fragment: None,
65                userinfo: None,
66            };
67
68            let s = s.trim();
69
70            // Parse scheme
71            if let Some(pos) = s.find("://") {
72                uri.scheme = s[..pos].to_string();
73                let rest = &s[pos + 3..];
74
75                // Parse host and port
76                let (authority, path_and_rest) = if let Some(pos) = rest.find('/') {
77                    (&rest[..pos], &rest[pos..])
78                } else {
79                    (rest, "/")
80                };
81
82                // Parse userinfo
83                let host_port = if let Some(pos) = authority.find('@') {
84                    uri.userinfo = Some(authority[..pos].to_string());
85                    &authority[pos + 1..]
86                } else {
87                    authority
88                };
89
90                // Parse host and port
91                if let Some(pos) = host_port.rfind(':') {
92                    uri.host = host_port[..pos].to_string();
93                    if let Ok(port) = host_port[pos + 1..].parse() {
94                        uri.port = Some(port);
95                    }
96                } else {
97                    uri.host = host_port.to_string();
98                }
99
100                // Parse path, query, fragment
101                let (path_query, fragment) = if let Some(pos) = path_and_rest.find('#') {
102                    uri.fragment = Some(path_and_rest[pos + 1..].to_string());
103                    (&path_and_rest[..pos], Some(&path_and_rest[pos + 1..]))
104                } else {
105                    (path_and_rest, None)
106                };
107
108                if let Some(pos) = path_query.find('?') {
109                    uri.path = path_query[..pos].to_string();
110                    uri.query = Some(path_query[pos + 1..].to_string());
111                } else {
112                    uri.path = path_query.to_string();
113                }
114            } else {
115                return Err(ProtocolError::InvalidUri("Missing scheme".to_string()));
116            }
117
118            Ok(uri)
119        }
120    }
121
122    /// Reconstruct the full URI string
123    pub fn to_string(&self) -> String {
124        let mut s = format!("{}://", self.scheme);
125
126        if let Some(ref userinfo) = self.userinfo {
127            s.push_str(userinfo);
128            s.push('@');
129        }
130
131        s.push_str(&self.host);
132
133        if let Some(port) = self.port {
134            s.push(':');
135            s.push_str(&port.to_string());
136        }
137
138        s.push_str(&self.path);
139
140        if let Some(ref query) = self.query {
141            s.push('?');
142            s.push_str(query);
143        }
144
145        if let Some(ref fragment) = self.fragment {
146            s.push('#');
147            s.push_str(fragment);
148        }
149
150        s
151    }
152
153    /// Get the authority portion (host:port)
154    pub fn authority(&self) -> String {
155        if let Some(port) = self.port {
156            format!("{}:{}", self.host, port)
157        } else {
158            self.host.clone()
159        }
160    }
161}
162
163impl fmt::Display for Uri {
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        write!(f, "{}", self.to_string())
166    }
167}
168
/// HTTP-style headers collection.
///
/// Header names are case-insensitive: keys are lowercased (ASCII only) both when stored and
/// when looked up. A single name may carry several values (see [`Headers::insert`]).
#[derive(Debug, Clone, Default)]
pub struct Headers {
    inner: HashMap<String, Vec<String>>,
}

impl Headers {
    /// Create empty headers.
    pub fn new() -> Self {
        Headers {
            inner: HashMap::new(),
        }
    }

    /// Insert a header value, appending to any values already stored under `key`.
    ///
    /// Returns `self` so calls can be chained.
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
        // HTTP field names are ASCII tokens. Only ASCII letters are folded, because full
        // Unicode lowercasing would let distinct names collide (U+212A KELVIN SIGN
        // lowercases to 'k', so "Set-Coo\u{212A}ie" would alias "set-cookie").
        let mut key: String = key.into();
        key.make_ascii_lowercase();
        self.inner.entry(key).or_default().push(value.into());
        self
    }

    /// Set a header to a single value, replacing any existing values.
    ///
    /// Returns `self` so calls can be chained.
    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
        let mut key: String = key.into();
        key.make_ascii_lowercase();
        self.inner.insert(key, vec![value.into()]);
        self
    }

    /// Get the first value stored for a header, if any.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.get_all(key)
            .and_then(|values| values.first())
            .map(String::as_str)
    }

    /// Get all values stored for a header, in insertion order.
    pub fn get_all(&self, key: &str) -> Option<&[String]> {
        self.inner
            .get(&key.to_ascii_lowercase())
            .map(Vec::as_slice)
    }

    /// Remove a header, returning its values if it was present.
    pub fn remove(&mut self, key: &str) -> Option<Vec<String>> {
        self.inner.remove(&key.to_ascii_lowercase())
    }

    /// Check whether a header is present.
    pub fn contains(&self, key: &str) -> bool {
        self.inner.contains_key(&key.to_ascii_lowercase())
    }

    /// Get the (first) `content-type` header value.
    pub fn content_type(&self) -> Option<&str> {
        self.get("content-type")
    }

    /// Get the (first) `content-length` header value as a number.
    ///
    /// Surrounding whitespace is ignored; returns `None` if the header is missing or does
    /// not parse as a `u64`.
    pub fn content_length(&self) -> Option<u64> {
        self.get("content-length")
            .and_then(|value| value.trim().parse().ok())
    }

    /// Iterate over `(lowercased name, values)` pairs, in arbitrary (hash-map) order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &Vec<String>)> {
        self.inner.iter()
    }

    /// Number of distinct header names (not the total number of values).
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Check whether no headers are stored.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}

/// Consumes the headers, yielding `(lowercased name, values)` pairs in arbitrary order.
impl IntoIterator for Headers {
    type Item = (String, Vec<String>);
    type IntoIter = std::collections::hash_map::IntoIter<String, Vec<String>>;

    fn into_iter(self) -> Self::IntoIter {
        self.inner.into_iter()
    }
}

/// Builds headers from `(name, value)` pairs using [`Headers::insert`], so repeated names
/// accumulate values rather than replacing them.
impl<K: Into<String>, V: Into<String>> FromIterator<(K, V)> for Headers {
    fn from_iter<I: IntoIterator<Item = (K, V)>>(iter: I) -> Self {
        let mut headers = Headers::new();
        for (k, v) in iter {
            headers.insert(k, v);
        }
        headers
    }
}
265
/// HTTP status codes.
///
/// Any `u16` can be wrapped via [`StatusCode::from_u16`]; the associated constants name the
/// most common codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StatusCode(u16);

impl StatusCode {
    // Informational
    pub const CONTINUE: StatusCode = StatusCode(100);
    pub const SWITCHING_PROTOCOLS: StatusCode = StatusCode(101);

    // Success
    pub const OK: StatusCode = StatusCode(200);
    pub const CREATED: StatusCode = StatusCode(201);
    pub const ACCEPTED: StatusCode = StatusCode(202);
    pub const NO_CONTENT: StatusCode = StatusCode(204);

    // Redirection
    pub const MOVED_PERMANENTLY: StatusCode = StatusCode(301);
    pub const FOUND: StatusCode = StatusCode(302);
    pub const SEE_OTHER: StatusCode = StatusCode(303);
    pub const NOT_MODIFIED: StatusCode = StatusCode(304);
    pub const TEMPORARY_REDIRECT: StatusCode = StatusCode(307);
    pub const PERMANENT_REDIRECT: StatusCode = StatusCode(308);

    // Client Error
    pub const BAD_REQUEST: StatusCode = StatusCode(400);
    pub const UNAUTHORIZED: StatusCode = StatusCode(401);
    pub const FORBIDDEN: StatusCode = StatusCode(403);
    pub const NOT_FOUND: StatusCode = StatusCode(404);
    pub const METHOD_NOT_ALLOWED: StatusCode = StatusCode(405);
    pub const CONFLICT: StatusCode = StatusCode(409);
    pub const GONE: StatusCode = StatusCode(410);
    pub const UNPROCESSABLE_ENTITY: StatusCode = StatusCode(422);
    pub const TOO_MANY_REQUESTS: StatusCode = StatusCode(429);

    // Server Error
    pub const INTERNAL_SERVER_ERROR: StatusCode = StatusCode(500);
    pub const NOT_IMPLEMENTED: StatusCode = StatusCode(501);
    pub const BAD_GATEWAY: StatusCode = StatusCode(502);
    pub const SERVICE_UNAVAILABLE: StatusCode = StatusCode(503);
    pub const GATEWAY_TIMEOUT: StatusCode = StatusCode(504);

    /// Create a status code from a number (no range validation is performed).
    pub fn from_u16(code: u16) -> Self {
        StatusCode(code)
    }

    /// Get the numeric code.
    pub fn as_u16(&self) -> u16 {
        self.0
    }

    /// Check if this is an informational status (1xx).
    pub fn is_informational(&self) -> bool {
        (100..200).contains(&self.0)
    }

    /// Check if this is a success status (2xx).
    pub fn is_success(&self) -> bool {
        (200..300).contains(&self.0)
    }

    /// Check if this is a redirection status (3xx).
    pub fn is_redirection(&self) -> bool {
        (300..400).contains(&self.0)
    }

    /// Check if this is a client error status (4xx).
    pub fn is_client_error(&self) -> bool {
        (400..500).contains(&self.0)
    }

    /// Check if this is a server error status (5xx).
    pub fn is_server_error(&self) -> bool {
        (500..600).contains(&self.0)
    }

    /// Get the reason phrase for this status code, or `"Unknown"` for unlisted codes.
    pub fn reason_phrase(&self) -> &'static str {
        match self.0 {
            100 => "Continue",
            101 => "Switching Protocols",
            103 => "Early Hints",
            200 => "OK",
            201 => "Created",
            202 => "Accepted",
            203 => "Non-Authoritative Information",
            204 => "No Content",
            205 => "Reset Content",
            206 => "Partial Content",
            300 => "Multiple Choices",
            301 => "Moved Permanently",
            302 => "Found",
            303 => "See Other",
            304 => "Not Modified",
            305 => "Use Proxy",
            307 => "Temporary Redirect",
            308 => "Permanent Redirect",
            400 => "Bad Request",
            401 => "Unauthorized",
            402 => "Payment Required",
            403 => "Forbidden",
            404 => "Not Found",
            405 => "Method Not Allowed",
            406 => "Not Acceptable",
            407 => "Proxy Authentication Required",
            408 => "Request Timeout",
            409 => "Conflict",
            410 => "Gone",
            411 => "Length Required",
            412 => "Precondition Failed",
            413 => "Payload Too Large",
            414 => "URI Too Long",
            415 => "Unsupported Media Type",
            416 => "Range Not Satisfiable",
            417 => "Expectation Failed",
            421 => "Misdirected Request",
            422 => "Unprocessable Entity",
            425 => "Too Early",
            426 => "Upgrade Required",
            428 => "Precondition Required",
            429 => "Too Many Requests",
            431 => "Request Header Fields Too Large",
            451 => "Unavailable For Legal Reasons",
            500 => "Internal Server Error",
            501 => "Not Implemented",
            502 => "Bad Gateway",
            503 => "Service Unavailable",
            504 => "Gateway Timeout",
            505 => "HTTP Version Not Supported",
            511 => "Network Authentication Required",
            _ => "Unknown",
        }
    }
}

/// Formats as `"<code> <reason phrase>"`, e.g. `"404 Not Found"`.
impl fmt::Display for StatusCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} {}", self.0, self.reason_phrase())
    }
}
381
/// Timeout configuration for protocol operations.
///
/// Each phase has its own optional limit; `None` leaves that phase unset.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Timeout {
    /// Connection timeout
    pub connect: Option<Duration>,
    /// Read timeout
    pub read: Option<Duration>,
    /// Write timeout
    pub write: Option<Duration>,
    /// Total operation timeout
    pub total: Option<Duration>,
}

impl Timeout {
    /// Create a configuration with no timeouts set (identical to `Timeout::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a configuration where every timeout is set to `duration`.
    pub fn all(duration: Duration) -> Self {
        Timeout {
            connect: Some(duration),
            read: Some(duration),
            write: Some(duration),
            total: Some(duration),
        }
    }

    /// Set the connection timeout (builder style).
    #[must_use]
    pub fn connect_timeout(mut self, duration: Duration) -> Self {
        self.connect = Some(duration);
        self
    }

    /// Set the read timeout (builder style).
    #[must_use]
    pub fn read_timeout(mut self, duration: Duration) -> Self {
        self.read = Some(duration);
        self
    }

    /// Set the write timeout (builder style).
    #[must_use]
    pub fn write_timeout(mut self, duration: Duration) -> Self {
        self.write = Some(duration);
        self
    }

    /// Set the total operation timeout (builder style).
    #[must_use]
    pub fn total_timeout(mut self, duration: Duration) -> Self {
        self.total = Some(duration);
        self
    }
}
446
/// Backoff strategy for retries.
#[derive(Debug, Clone)]
pub enum BackoffStrategy {
    /// Fixed delay between retries.
    Fixed(Duration),
    /// Delay of `initial + increment * attempt`, optionally capped at `max`.
    Linear {
        initial: Duration,
        increment: Duration,
        max: Option<Duration>,
    },
    /// Delay of `initial * factor^attempt`, optionally capped at `max`.
    Exponential {
        initial: Duration,
        factor: f64,
        max: Option<Duration>,
    },
}

impl BackoffStrategy {
    /// Calculate the delay for a given attempt number (0-indexed).
    ///
    /// The arithmetic saturates instead of panicking:
    ///
    /// - A delay that would overflow [`Duration`] becomes [`Duration::MAX`] before the
    ///   optional `max` cap is applied.
    /// - An exponential delay that is negative or NaN (from a negative or NaN `factor`)
    ///   becomes [`Duration::ZERO`].
    ///
    /// `Fixed` ignores `attempt` and has no cap.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let (delay, max) = match self {
            BackoffStrategy::Fixed(d) => return *d,
            BackoffStrategy::Linear {
                initial,
                increment,
                max,
            } => (initial.saturating_add(increment.saturating_mul(attempt)), max),
            BackoffStrategy::Exponential {
                initial,
                factor,
                max,
            } => {
                // `powi` takes an i32; clamp rather than let a huge attempt count wrap into
                // a negative exponent.
                let exponent = attempt.min(i32::MAX as u32) as i32;
                let secs = initial.as_secs_f64() * factor.powi(exponent);
                (Self::saturating_from_secs_f64(secs), max)
            }
        };
        max.map_or(delay, |cap| delay.min(cap))
    }

    /// Convert seconds to a `Duration`, clamping NaN and non-positive values to zero and
    /// anything too large (including infinity) to `Duration::MAX`. This avoids the panics
    /// of `Duration::from_secs_f64` / `Duration::mul_f64`.
    fn saturating_from_secs_f64(secs: f64) -> Duration {
        if secs.is_nan() || secs <= 0.0 {
            Duration::ZERO
        } else if secs >= Duration::MAX.as_secs_f64() {
            Duration::MAX
        } else {
            Duration::from_secs_f64(secs)
        }
    }
}
491
492/// Retry configuration
493#[derive(Debug, Clone)]
494pub struct RetryConfig {
495    /// Maximum number of attempts (including initial)
496    pub max_attempts: u32,
497    /// Backoff strategy
498    pub backoff: BackoffStrategy,
499    /// Status codes to retry on
500    pub retry_on_status: Vec<StatusCode>,
501    /// Whether to retry on connection errors
502    pub retry_on_connection_error: bool,
503    /// Whether to retry on timeout
504    pub retry_on_timeout: bool,
505}
506
507impl RetryConfig {
508    /// Create a new retry configuration
509    pub fn new(max_attempts: u32) -> Self {
510        RetryConfig {
511            max_attempts,
512            backoff: BackoffStrategy::Exponential {
513                initial: Duration::from_millis(100),
514                factor: 2.0,
515                max: Some(Duration::from_secs(30)),
516            },
517            retry_on_status: vec![
518                StatusCode::SERVICE_UNAVAILABLE,
519                StatusCode::GATEWAY_TIMEOUT,
520                StatusCode::TOO_MANY_REQUESTS,
521            ],
522            retry_on_connection_error: true,
523            retry_on_timeout: true,
524        }
525    }
526
527    /// Set the backoff strategy
528    pub fn backoff(mut self, strategy: BackoffStrategy) -> Self {
529        self.backoff = strategy;
530        self
531    }
532
533    /// Set status codes to retry on
534    pub fn retry_on(mut self, codes: Vec<StatusCode>) -> Self {
535        self.retry_on_status = codes;
536        self
537    }
538
539    /// Check if a status code should be retried
540    pub fn should_retry_status(&self, status: StatusCode) -> bool {
541        self.retry_on_status.contains(&status)
542    }
543}
544
545impl Default for RetryConfig {
546    fn default() -> Self {
547        RetryConfig::new(3)
548    }
549}
550
551/// Protocol error types
552#[derive(Debug, Clone)]
553pub enum ProtocolError {
554    /// Invalid URI format
555    InvalidUri(String),
556    /// Connection failed
557    ConnectionFailed(String),
558    /// Connection timeout
559    ConnectionTimeout,
560    /// Read timeout
561    ReadTimeout,
562    /// Write timeout
563    WriteTimeout,
564    /// Request timeout
565    RequestTimeout,
566    /// TLS/SSL error
567    TlsError(String),
568    /// Protocol-specific error
569    Protocol(String),
570    /// Serialization error
571    Serialization(String),
572    /// Deserialization error
573    Deserialization(String),
574    /// Authentication error
575    Authentication(String),
576    /// Authorization error
577    Authorization(String),
578    /// Rate limited
579    RateLimited { retry_after: Option<Duration> },
580    /// Resource not found
581    NotFound(String),
582    /// Server error
583    ServerError(StatusCode, String),
584    /// Client error
585    ClientError(StatusCode, String),
586    /// IO error
587    Io(String),
588    /// Channel closed
589    ChannelClosed,
590    /// Operation cancelled
591    Cancelled,
592    /// Other error
593    Other(String),
594}
595
596impl fmt::Display for ProtocolError {
597    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
598        match self {
599            ProtocolError::InvalidUri(msg) => write!(f, "Invalid URI: {}", msg),
600            ProtocolError::ConnectionFailed(msg) => write!(f, "Connection failed: {}", msg),
601            ProtocolError::ConnectionTimeout => write!(f, "Connection timeout"),
602            ProtocolError::ReadTimeout => write!(f, "Read timeout"),
603            ProtocolError::WriteTimeout => write!(f, "Write timeout"),
604            ProtocolError::RequestTimeout => write!(f, "Request timeout"),
605            ProtocolError::TlsError(msg) => write!(f, "TLS error: {}", msg),
606            ProtocolError::Protocol(msg) => write!(f, "Protocol error: {}", msg),
607            ProtocolError::Serialization(msg) => write!(f, "Serialization error: {}", msg),
608            ProtocolError::Deserialization(msg) => write!(f, "Deserialization error: {}", msg),
609            ProtocolError::Authentication(msg) => write!(f, "Authentication error: {}", msg),
610            ProtocolError::Authorization(msg) => write!(f, "Authorization error: {}", msg),
611            ProtocolError::RateLimited { retry_after } => {
612                if let Some(d) = retry_after {
613                    write!(f, "Rate limited, retry after {:?}", d)
614                } else {
615                    write!(f, "Rate limited")
616                }
617            }
618            ProtocolError::NotFound(msg) => write!(f, "Not found: {}", msg),
619            ProtocolError::ServerError(code, msg) => write!(f, "Server error ({}): {}", code, msg),
620            ProtocolError::ClientError(code, msg) => write!(f, "Client error ({}): {}", code, msg),
621            ProtocolError::Io(msg) => write!(f, "IO error: {}", msg),
622            ProtocolError::ChannelClosed => write!(f, "Channel closed"),
623            ProtocolError::Cancelled => write!(f, "Operation cancelled"),
624            ProtocolError::Other(msg) => write!(f, "{}", msg),
625        }
626    }
627}
628
629impl std::error::Error for ProtocolError {}
630
631/// Result type for protocol operations
632pub type ProtocolResult<T> = Result<T, ProtocolError>;
633
634/// HTTP Methods
635#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
636pub enum Method {
637    GET,
638    POST,
639    PUT,
640    DELETE,
641    PATCH,
642    HEAD,
643    OPTIONS,
644    CONNECT,
645    TRACE,
646}
647
648impl Method {
649    /// Get the method name as a string
650    pub fn as_str(&self) -> &'static str {
651        match self {
652            Method::GET => "GET",
653            Method::POST => "POST",
654            Method::PUT => "PUT",
655            Method::DELETE => "DELETE",
656            Method::PATCH => "PATCH",
657            Method::HEAD => "HEAD",
658            Method::OPTIONS => "OPTIONS",
659            Method::CONNECT => "CONNECT",
660            Method::TRACE => "TRACE",
661        }
662    }
663
664    /// Check if the method is idempotent
665    pub fn is_idempotent(&self) -> bool {
666        matches!(
667            self,
668            Method::GET
669                | Method::HEAD
670                | Method::PUT
671                | Method::DELETE
672                | Method::OPTIONS
673                | Method::TRACE
674        )
675    }
676
677    /// Check if the method is safe (no side effects)
678    pub fn is_safe(&self) -> bool {
679        matches!(
680            self,
681            Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE
682        )
683    }
684}
685
686impl fmt::Display for Method {
687    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
688        write!(f, "{}", self.as_str())
689    }
690}
691
692impl TryFrom<&str> for Method {
693    type Error = ProtocolError;
694
695    fn try_from(s: &str) -> Result<Self, Self::Error> {
696        match s.to_uppercase().as_str() {
697            "GET" => Ok(Method::GET),
698            "POST" => Ok(Method::POST),
699            "PUT" => Ok(Method::PUT),
700            "DELETE" => Ok(Method::DELETE),
701            "PATCH" => Ok(Method::PATCH),
702            "HEAD" => Ok(Method::HEAD),
703            "OPTIONS" => Ok(Method::OPTIONS),
704            "CONNECT" => Ok(Method::CONNECT),
705            "TRACE" => Ok(Method::TRACE),
706            _ => Err(ProtocolError::Protocol(format!(
707                "Unknown HTTP method: {}",
708                s
709            ))),
710        }
711    }
712}
713
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uri_parsing() {
        let input = "https://user:pass@example.com:8080/path?query=value#fragment";
        let uri = Uri::parse(input).expect("URI with every component should parse");

        let expected = Uri {
            scheme: "https".to_string(),
            host: "example.com".to_string(),
            port: Some(8080),
            path: "/path".to_string(),
            query: Some("query=value".to_string()),
            fragment: Some("fragment".to_string()),
            userinfo: Some("user:pass".to_string()),
        };
        assert_eq!(uri, expected);
    }

    #[test]
    fn test_headers() {
        // Built via FromIterator, which appends through `insert` exactly like manual calls.
        let headers: Headers = vec![
            ("Content-Type", "application/json"),
            ("X-Custom", "value1"),
            ("X-Custom", "value2"),
        ]
        .into_iter()
        .collect();

        assert_eq!(headers.get("content-type"), Some("application/json"));
        assert_eq!(headers.get_all("x-custom").map(|values| values.len()), Some(2));
    }

    #[test]
    fn test_status_code() {
        let checks: [(StatusCode, fn(&StatusCode) -> bool); 4] = [
            (StatusCode::OK, StatusCode::is_success),
            (StatusCode::NOT_FOUND, StatusCode::is_client_error),
            (StatusCode::INTERNAL_SERVER_ERROR, StatusCode::is_server_error),
            (StatusCode::MOVED_PERMANENTLY, StatusCode::is_redirection),
        ];
        for (code, predicate) in checks.iter() {
            assert!(predicate(code), "{} was misclassified", code);
        }
    }

    #[test]
    fn test_backoff_strategy() {
        let exp = BackoffStrategy::Exponential {
            initial: Duration::from_millis(100),
            factor: 2.0,
            max: Some(Duration::from_secs(10)),
        };

        let expected = [
            (0, Duration::from_millis(100)),
            (1, Duration::from_millis(200)),
            (2, Duration::from_millis(400)),
            (10, Duration::from_secs(10)), // capped at max
        ];
        for &(attempt, delay) in expected.iter() {
            assert_eq!(exp.delay_for_attempt(attempt), delay, "attempt {}", attempt);
        }
    }
}