Skip to main content

tiny_proxy/config/
parser.rs

1use crate::config::address::{extract_hostname, resolve_listen_addr};
2use crate::config::{Config, Directive, SiteConfig};
3use crate::error::ProxyError;
4use std::collections::HashMap;
5use std::net::SocketAddr;
6use std::str::FromStr;
7
/// A nested block (`handle_path`, `method`, `reverse_proxy`, ...) whose opening
/// line has been read but whose closing `}` has not been reached yet.
#[derive(Debug)]
struct PendingBlock {
    /// Keyword that opened the block, e.g. `"reverse_proxy"`.
    directive_type: String,
    /// Tokens following the keyword on the opening line (the `{` itself excluded).
    args: Vec<String>,
    /// `connect_timeout` given inside a `reverse_proxy` block, in seconds.
    connect_timeout: Option<u64>,
    /// `read_timeout` given inside a `reverse_proxy` block, in seconds.
    read_timeout: Option<u64>,
}
16
17/// Parse a human-readable duration string into seconds.
18///
19/// Supported formats:
20/// - Plain number: `"30"` → 30 seconds
21/// - Seconds: `"30s"` → 30
22/// - Minutes: `"5m"` → 300
23/// - Hours: `"2h"` → 7200
24/// - Days: `"1d"` → 86400
25fn parse_duration(s: &str) -> Result<u64, ProxyError> {
26    let s = s.trim();
27    if s.is_empty() {
28        return Err(ProxyError::Parse("Empty duration value".to_string()));
29    }
30
31    // Try plain number first (seconds)
32    if let Ok(secs) = s.parse::<u64>() {
33        return Ok(secs);
34    }
35
36    // Parse with suffix
37    let (num_part, multiplier) = if let Some(n) = s.strip_suffix('s') {
38        (n, 1u64)
39    } else if let Some(n) = s.strip_suffix('m') {
40        (n, 60u64)
41    } else if let Some(n) = s.strip_suffix('h') {
42        (n, 3600u64)
43    } else if let Some(n) = s.strip_suffix('d') {
44        (n, 86400u64)
45    } else {
46        return Err(ProxyError::Parse(format!(
47            "Invalid duration '{}'. Use a plain number or Ns/Nm/Nh/Nd",
48            s
49        )));
50    };
51
52    let value: u64 = num_part
53        .parse()
54        .map_err(|_| ProxyError::Parse(format!("Invalid numeric value in duration: '{}'", s)))?;
55
56    Ok(value * multiplier)
57}
58
59impl Config {
60    pub fn from_file(path: &str) -> Result<Self, ProxyError> {
61        let content = std::fs::read_to_string(path)?;
62        content.parse()
63    }
64}
65
66impl FromStr for Config {
67    type Err = ProxyError;
68
69    fn from_str(content: &str) -> Result<Self, Self::Err> {
70        let mut sites = HashMap::new();
71        let mut current_site_address: Option<String> = None;
72        let mut current_site_tls: Option<crate::config::TlsConfig> = None;
73
74        let mut directive_stack: Vec<Vec<Directive>> = vec![vec![]];
75        let mut block_stack: Vec<PendingBlock> = vec![];
76
77        for (line_num, raw_line) in content.lines().enumerate() {
78            let line = raw_line.trim();
79            if line.is_empty() || line.starts_with('#') {
80                continue;
81            }
82
83            // 1. Handle opening brace
84            if line.ends_with('{') {
85                let parts: Vec<&str> = line.split_whitespace().collect();
86                if parts.is_empty() {
87                    continue;
88                }
89
90                // Top-level site block
91                if directive_stack.len() == 1 && current_site_address.is_none() {
92                    current_site_address = Some(parts[0].to_string());
93                    continue;
94                }
95
96                // Nested block (handle_path, method, reverse_proxy, etc.)
97                let directive_type = parts[0].to_string();
98                // Filter out trailing "{" from args
99                let args = parts[1..]
100                    .iter()
101                    .filter(|s| **s != "{")
102                    .map(|s| s.to_string())
103                    .collect();
104
105                block_stack.push(PendingBlock {
106                    directive_type,
107                    args,
108                    connect_timeout: None,
109                    read_timeout: None,
110                });
111                directive_stack.push(vec![]);
112                continue;
113            }
114
115            // 2. Handle closing brace
116            if line == "}" {
117                if directive_stack.len() > 1 {
118                    let finished_directives = directive_stack
119                        .pop()
120                        .expect("directive_stack has at least 2 elements");
121                    let block_info = block_stack.pop().expect("block_stack has matching entry");
122
123                    let completed_directive = match block_info.directive_type.as_str() {
124                        "handle_path" => {
125                            let pattern = block_info.args.first().cloned().unwrap_or_default();
126                            Directive::HandlePath {
127                                pattern,
128                                directives: finished_directives,
129                            }
130                        }
131                        "method" => Directive::Method {
132                            methods: block_info.args,
133                            directives: finished_directives,
134                        },
135                        "reverse_proxy" => {
136                            let to = block_info.args.first().cloned().unwrap_or_default();
137                            Directive::ReverseProxy {
138                                to,
139                                connect_timeout: block_info.connect_timeout,
140                                read_timeout: block_info.read_timeout,
141                            }
142                        }
143                        _ => {
144                            return Err(ProxyError::Parse(format!(
145                                "Unknown block type: {}",
146                                block_info.directive_type
147                            )))
148                        }
149                    };
150
151                    directive_stack
152                        .last_mut()
153                        .expect("directive_stack has parent after pop")
154                        .push(completed_directive);
155                } else {
156                    // Site block closed
157                    if let Some(address) = current_site_address.take() {
158                        let site_directives = directive_stack
159                            .pop()
160                            .expect("site directive_stack is non-empty");
161                        if sites.contains_key(&address) {
162                            return Err(ProxyError::Parse(format!(
163                                "Duplicate site address '{}'. \
164                                 Each address may appear only once in the configuration.",
165                                address
166                            )));
167                        }
168                        sites.insert(
169                            address.clone(),
170                            SiteConfig {
171                                address,
172                                directives: site_directives,
173                                tls: current_site_tls.take(),
174                            },
175                        );
176                        directive_stack.push(vec![]);
177                    }
178                }
179                continue;
180            }
181
182            // 3. Handle simple directives (single line)
183            let parts: Vec<&str> = line.split_whitespace().collect();
184            if parts.is_empty() {
185                continue;
186            }
187
188            let directive_name = parts[0];
189            let args = parts[1..].to_vec();
190
191            // Special handling: timeout settings inside a reverse_proxy block
192            if let Some(block) = block_stack.last_mut() {
193                if block.directive_type == "reverse_proxy" {
194                    match directive_name {
195                        "connect_timeout" => {
196                            let raw = args.first().cloned().ok_or_else(|| {
197                                ProxyError::Parse("Missing value for connect_timeout".to_string())
198                            })?;
199                            block.connect_timeout = Some(parse_duration(raw).map_err(|e| {
200                                ProxyError::Parse(format!(
201                                    "Invalid connect_timeout on line {}: {}",
202                                    line_num + 1,
203                                    e
204                                ))
205                            })?);
206                            continue;
207                        }
208                        "read_timeout" => {
209                            let raw = args.first().cloned().ok_or_else(|| {
210                                ProxyError::Parse("Missing value for read_timeout".to_string())
211                            })?;
212                            block.read_timeout = Some(parse_duration(raw).map_err(|e| {
213                                ProxyError::Parse(format!(
214                                    "Invalid read_timeout on line {}: {}",
215                                    line_num + 1,
216                                    e
217                                ))
218                            })?);
219                            continue;
220                        }
221                        _ => {
222                            return Err(ProxyError::Parse(format!(
223                                "Unexpected directive '{}' inside reverse_proxy block on line {}. Only connect_timeout and read_timeout are allowed.",
224                                directive_name, line_num + 1
225                            )));
226                        }
227                    }
228                }
229            }
230
231            // Special handling: tls directive at site level
232            if directive_name == "tls" && block_stack.is_empty() {
233                let cert_path = args.first().cloned().ok_or_else(|| {
234                    ProxyError::Parse(format!(
235                        "Missing cert path for tls directive on line {}",
236                        line_num + 1
237                    ))
238                })?;
239                let key_path = args.get(1).cloned().ok_or_else(|| {
240                    ProxyError::Parse(format!(
241                        "Missing key path for tls directive on line {}",
242                        line_num + 1
243                    ))
244                })?;
245                if current_site_tls.is_some() {
246                    return Err(ProxyError::Parse(format!(
247                        "Duplicate tls directive on line {}. Only one tls per site is allowed.",
248                        line_num + 1
249                    )));
250                }
251                current_site_tls = Some(crate::config::TlsConfig {
252                    cert_path: cert_path.to_string(),
253                    key_path: key_path.to_string(),
254                });
255                continue;
256            }
257
258            // Regular directive parsing
259            let directive = match directive_name {
260                "reverse_proxy" => {
261                    let to = args.first().cloned().ok_or_else(|| {
262                        ProxyError::Parse("Missing backend URL for reverse_proxy".to_string())
263                    })?;
264                    Directive::ReverseProxy {
265                        to: to.to_string(),
266                        connect_timeout: None,
267                        read_timeout: None,
268                    }
269                }
270                "uri_replace" => {
271                    let find = args.first().cloned().ok_or_else(|| {
272                        ProxyError::Parse("Missing 'find' arg for uri_replace".to_string())
273                    })?;
274                    let replace = args.get(1).cloned().ok_or_else(|| {
275                        ProxyError::Parse("Missing 'replace' arg for uri_replace".to_string())
276                    })?;
277                    Directive::UriReplace {
278                        find: find.to_string(),
279                        replace: replace.to_string(),
280                    }
281                }
282                "header" => {
283                    let raw_name = args.first().cloned().ok_or_else(|| {
284                        ProxyError::Parse("Missing 'name' arg for header".to_string())
285                    })?;
286                    if let Some(name) = raw_name.strip_prefix('-') {
287                        if name.is_empty() {
288                            return Err(ProxyError::Parse(
289                                "Missing header name after '-' for header removal".to_string(),
290                            ));
291                        }
292                        Directive::Header {
293                            name: name.to_string(),
294                            value: None,
295                        }
296                    } else {
297                        let value = args.get(1).cloned().ok_or_else(|| {
298                            ProxyError::Parse("Missing 'value' arg for header".to_string())
299                        })?;
300                        Directive::Header {
301                            name: raw_name.to_string(),
302                            value: Some(value.to_string()),
303                        }
304                    }
305                }
306                "respond" => {
307                    let status = args.first().and_then(|s| s.parse().ok()).ok_or_else(|| {
308                        ProxyError::Parse("Invalid status for respond".to_string())
309                    })?;
310                    let body = args.get(1).cloned().unwrap_or_default();
311                    Directive::Respond {
312                        status,
313                        body: body.to_string(),
314                    }
315                }
316                "strip_prefix" => {
317                    let prefix = args.first().cloned().ok_or_else(|| {
318                        ProxyError::Parse("Missing 'prefix' arg for strip_prefix".to_string())
319                    })?;
320                    Directive::StripPrefix {
321                        prefix: prefix.to_string(),
322                    }
323                }
324                "redirect" => {
325                    let (status, url) = if args.len() >= 2 {
326                        let status: u16 = args[0].parse().map_err(|_| {
327                            ProxyError::Parse(format!(
328                                "Invalid status code for redirect: {}",
329                                args[0]
330                            ))
331                        })?;
332                        let url = args[1..].join(" ");
333                        (status, url)
334                    } else {
335                        let url = args.first().cloned().ok_or_else(|| {
336                            ProxyError::Parse("Missing 'url' arg for redirect".to_string())
337                        })?;
338                        (301u16, url.to_string())
339                    };
340                    Directive::Redirect {
341                        status,
342                        url: url.to_string(),
343                    }
344                }
345                _ => {
346                    return Err(ProxyError::Parse(format!(
347                        "Unknown directive '{}' on line {}",
348                        directive_name,
349                        line_num + 1
350                    )))
351                }
352            };
353
354            directive_stack
355                .last_mut()
356                .expect("directive_stack is non-empty")
357                .push(directive);
358        }
359
360        validate_listen_sockets(&sites)?;
361
362        Ok(Config { sites })
363    }
364}
365
366/// Validate TLS/plain consistency and unique SNI hostnames per listen socket.
367fn validate_listen_sockets(sites: &HashMap<String, SiteConfig>) -> Result<(), ProxyError> {
368    let mut socket_tls: HashMap<SocketAddr, bool> = HashMap::new();
369    let mut socket_sni: HashMap<SocketAddr, HashMap<String, String>> = HashMap::new();
370
371    for site in sites.values() {
372        let listen_addr =
373            resolve_listen_addr(&site.address).map_err(|e| ProxyError::Parse(e.to_string()))?;
374        let is_tls = site.tls.is_some();
375
376        if let Some(&prev_tls) = socket_tls.get(&listen_addr) {
377            if prev_tls != is_tls {
378                return Err(ProxyError::Parse(format!(
379                    "Mixed TLS and non-TLS sites on the same listen address {} is not supported. \
380                     Site '{}' is {} but conflicts with another site on this socket.",
381                    listen_addr,
382                    site.address,
383                    if is_tls { "TLS" } else { "plain HTTP" }
384                )));
385            }
386        } else {
387            socket_tls.insert(listen_addr, is_tls);
388        }
389
390        if is_tls {
391            let sni = extract_hostname(&site.address).to_ascii_lowercase();
392            let sni_map = socket_sni.entry(listen_addr).or_default();
393            if let Some(existing) = sni_map.get(&sni) {
394                return Err(ProxyError::Parse(format!(
395                    "Duplicate SNI hostname '{}' on listen address {} (sites '{}' and '{}')",
396                    sni, listen_addr, existing, site.address
397                )));
398            }
399            sni_map.insert(sni, site.address.clone());
400        }
401    }
402
403    Ok(())
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `src`, panicking with the parser's error if it is rejected.
    fn parse_ok(src: &str) -> Config {
        src.parse().expect("configuration should parse")
    }

    /// Parse `src`, assert that it is rejected, and return the rendered error.
    fn parse_err(src: &str) -> String {
        match src.parse::<Config>() {
            Ok(_) => panic!("expected the configuration to be rejected"),
            Err(err) => err.to_string(),
        }
    }

    /// Split a `ReverseProxy` directive into `(to, connect_timeout, read_timeout)`.
    fn reverse_proxy_fields(directive: &Directive) -> (&str, Option<u64>, Option<u64>) {
        match directive {
            Directive::ReverseProxy {
                to,
                connect_timeout,
                read_timeout,
            } => (to.as_str(), *connect_timeout, *read_timeout),
            _ => panic!("Expected ReverseProxy directive"),
        }
    }

    #[test]
    fn test_parse_duration_seconds() {
        for input in ["30", "30s"].iter() {
            assert_eq!(parse_duration(input).unwrap(), 30);
        }
    }

    #[test]
    fn test_parse_duration_minutes() {
        assert_eq!(parse_duration("5m").unwrap(), 5 * 60);
    }

    #[test]
    fn test_parse_duration_hours() {
        assert_eq!(parse_duration("2h").unwrap(), 2 * 60 * 60);
    }

    #[test]
    fn test_parse_duration_days() {
        assert_eq!(parse_duration("1d").unwrap(), 24 * 60 * 60);
    }

    #[test]
    fn test_parse_duration_invalid() {
        for input in ["", "abc", "10x"].iter() {
            assert!(parse_duration(input).is_err(), "{:?} should be rejected", input);
        }
    }

    #[test]
    fn test_parse_reverse_proxy_simple() {
        let config = parse_ok("localhost:8080 {\n    reverse_proxy http://backend:9001\n}");
        let site = &config.sites["localhost:8080"];

        assert_eq!(site.directives.len(), 1);
        let (to, connect_timeout, read_timeout) = reverse_proxy_fields(&site.directives[0]);
        assert_eq!(to, "http://backend:9001");
        assert_eq!(connect_timeout, None);
        assert_eq!(read_timeout, None);
    }

    #[test]
    fn test_parse_reverse_proxy_with_timeouts() {
        let config = parse_ok(
            r#"localhost:8080 {
    reverse_proxy http://backend:9001 {
        connect_timeout 10s
        read_timeout 5m
    }
}"#,
        );
        let site = &config.sites["localhost:8080"];

        assert_eq!(site.directives.len(), 1);
        let (to, connect_timeout, read_timeout) = reverse_proxy_fields(&site.directives[0]);
        assert_eq!(to, "http://backend:9001");
        assert_eq!(connect_timeout, Some(10));
        assert_eq!(read_timeout, Some(300));
    }

    #[test]
    fn test_parse_reverse_proxy_with_connect_timeout_only() {
        let config = parse_ok(
            r#"localhost:8080 {
    reverse_proxy http://backend:9001 {
        connect_timeout 5s
    }
}"#,
        );
        let site = &config.sites["localhost:8080"];

        let (_, connect_timeout, read_timeout) = reverse_proxy_fields(&site.directives[0]);
        assert_eq!(connect_timeout, Some(5));
        assert_eq!(read_timeout, None);
    }

    #[test]
    fn test_parse_reverse_proxy_block_rejects_unknown_directive() {
        let err_msg = parse_err(
            r#"localhost:8080 {
    reverse_proxy http://backend:9001 {
        unknown_setting 42
    }
}"#,
        );
        assert!(err_msg.contains("Unexpected directive"), "{}", err_msg);
    }

    #[test]
    fn test_parse_tls_directive() {
        let config = parse_ok(
            r#"example.com:443 {
    tls /etc/ssl/cert.pem /etc/ssl/key.pem
    reverse_proxy backend:8080
}"#,
        );
        let site = &config.sites["example.com:443"];

        let tls = site.tls.as_ref().expect("tls should be configured");
        assert_eq!(tls.cert_path, "/etc/ssl/cert.pem");
        assert_eq!(tls.key_path, "/etc/ssl/key.pem");

        assert_eq!(site.directives.len(), 1);
        assert_eq!(reverse_proxy_fields(&site.directives[0]).0, "backend:8080");
    }

    #[test]
    fn test_parse_tls_missing_cert_path() {
        let err_msg = parse_err("example.com:443 {\n    tls\n}");
        assert!(err_msg.contains("Missing cert path"), "{}", err_msg);
    }

    #[test]
    fn test_parse_tls_missing_key_path() {
        let err_msg = parse_err("example.com:443 {\n    tls /etc/ssl/cert.pem\n}");
        assert!(err_msg.contains("Missing key path"), "{}", err_msg);
    }

    #[test]
    fn test_parse_tls_duplicate_rejected() {
        let err_msg = parse_err(
            r#"example.com:443 {
    tls /a/cert.pem /a/key.pem
    tls /b/cert.pem /b/key.pem
    reverse_proxy backend:8080
}"#,
        );
        assert!(err_msg.contains("Duplicate tls"), "{}", err_msg);
    }

    #[test]
    fn test_parse_mixed_tls_and_non_tls_sites() {
        let config = parse_ok(
            r#"localhost:8080 {
    reverse_proxy backend:3000
}
example.com:443 {
    tls /etc/ssl/cert.pem /etc/ssl/key.pem
    reverse_proxy backend:8080
}"#,
        );

        // Plain HTTP site.
        assert!(config.sites["localhost:8080"].tls.is_none());
        // HTTPS site.
        assert!(config.sites["example.com:443"].tls.is_some());
    }

    #[test]
    fn test_parse_duplicate_address_rejected() {
        let err_msg = parse_err(
            r#"example.com:443 {
    tls /a/cert.pem /a/key.pem
    reverse_proxy backend:8080
}
example.com:443 {
    reverse_proxy backend:9000
}"#,
        );
        assert!(
            err_msg.contains("Duplicate site address"),
            "Expected 'Duplicate site address' error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_parse_mixed_tls_on_same_listen_socket_rejected() {
        let err_msg = parse_err(
            r#"example.com:443 {
    tls /etc/ssl/cert.pem /etc/ssl/key.pem
    reverse_proxy backend:8080
}
0.0.0.0:443 {
    reverse_proxy backend:3000
}"#,
        );
        assert!(err_msg.contains("Mixed TLS and non-TLS"), "got: {}", err_msg);
    }

    #[test]
    fn test_parse_duplicate_sni_on_same_listen_socket_rejected() {
        let err_msg = parse_err(
            r#"Example.com:8443 {
    tls /a/cert.pem /a/key.pem
    respond 200 "A"
}
example.com:8443 {
    tls /b/cert.pem /b/key.pem
    respond 200 "B"
}"#,
        );
        assert!(err_msg.contains("Duplicate SNI hostname"), "got: {}", err_msg);
    }
}