Skip to main content

tiny_proxy/config/
parser.rs

1use crate::config::address::{extract_hostname, resolve_listen_addr};
2use crate::config::models::HeaderDirective;
3use crate::config::{Config, Directive, SiteConfig};
4use crate::error::ProxyError;
5use std::collections::HashMap;
6use std::net::SocketAddr;
7use std::str::FromStr;
8
/// A nested block (`handle_path`, `method`, `reverse_proxy`, ...) whose opening
/// line has been read by `Config::from_str` but whose closing `}` has not been
/// reached yet. It is popped and converted into a `Directive` when the matching
/// `}` is seen.
#[derive(Debug)]
struct PendingBlock {
    // Block keyword: the first whitespace-separated token of the opening line.
    directive_type: String,
    // Remaining tokens of the opening line (e.g. the path pattern, HTTP methods,
    // or backend URL), with standalone `{` tokens removed.
    args: Vec<String>,
    // Timeout settings for reverse_proxy blocks (in seconds), filled in from
    // `connect_timeout` / `read_timeout` lines inside the block.
    connect_timeout: Option<u64>,
    read_timeout: Option<u64>,
    // header_up operations collected inside a reverse_proxy block, in the order
    // they appear in the configuration.
    header_up: Vec<HeaderDirective>,
}
19
20/// Parse a human-readable duration string into seconds.
21///
22/// Supported formats:
23/// - Plain number: `"30"` → 30 seconds
24/// - Seconds: `"30s"` → 30
25/// - Minutes: `"5m"` → 300
26/// - Hours: `"2h"` → 7200
27/// - Days: `"1d"` → 86400
28fn parse_duration(s: &str) -> Result<u64, ProxyError> {
29    let s = s.trim();
30    if s.is_empty() {
31        return Err(ProxyError::Parse("Empty duration value".to_string()));
32    }
33
34    // Try plain number first (seconds)
35    if let Ok(secs) = s.parse::<u64>() {
36        return Ok(secs);
37    }
38
39    // Parse with suffix
40    let (num_part, multiplier) = if let Some(n) = s.strip_suffix('s') {
41        (n, 1u64)
42    } else if let Some(n) = s.strip_suffix('m') {
43        (n, 60u64)
44    } else if let Some(n) = s.strip_suffix('h') {
45        (n, 3600u64)
46    } else if let Some(n) = s.strip_suffix('d') {
47        (n, 86400u64)
48    } else {
49        return Err(ProxyError::Parse(format!(
50            "Invalid duration '{}'. Use a plain number or Ns/Nm/Nh/Nd",
51            s
52        )));
53    };
54
55    let value: u64 = num_part
56        .parse()
57        .map_err(|_| ProxyError::Parse(format!("Invalid numeric value in duration: '{}'", s)))?;
58
59    Ok(value * multiplier)
60}
61
62impl Config {
63    pub fn from_file(path: &str) -> Result<Self, ProxyError> {
64        let content = std::fs::read_to_string(path)?;
65        content.parse()
66    }
67}
68
69impl FromStr for Config {
70    type Err = ProxyError;
71
72    fn from_str(content: &str) -> Result<Self, Self::Err> {
73        let mut sites = HashMap::new();
74        let mut current_site_address: Option<String> = None;
75        let mut current_site_tls: Option<crate::config::TlsConfig> = None;
76
77        let mut directive_stack: Vec<Vec<Directive>> = vec![vec![]];
78        let mut block_stack: Vec<PendingBlock> = vec![];
79
80        for (line_num, raw_line) in content.lines().enumerate() {
81            let line = raw_line.trim();
82            if line.is_empty() || line.starts_with('#') {
83                continue;
84            }
85
86            // 1. Handle opening brace
87            if line.ends_with('{') {
88                let parts: Vec<&str> = line.split_whitespace().collect();
89                if parts.is_empty() {
90                    continue;
91                }
92
93                // Top-level site block
94                if directive_stack.len() == 1 && current_site_address.is_none() {
95                    current_site_address = Some(parts[0].to_string());
96                    continue;
97                }
98
99                // Nested block (handle_path, method, reverse_proxy, etc.)
100                let directive_type = parts[0].to_string();
101                // Filter out trailing "{" from args
102                let args = parts[1..]
103                    .iter()
104                    .filter(|s| **s != "{")
105                    .map(|s| s.to_string())
106                    .collect();
107
108                block_stack.push(PendingBlock {
109                    directive_type,
110                    args,
111                    connect_timeout: None,
112                    read_timeout: None,
113                    header_up: vec![],
114                });
115                directive_stack.push(vec![]);
116                continue;
117            }
118
119            // 2. Handle closing brace
120            if line == "}" {
121                if directive_stack.len() > 1 {
122                    let finished_directives = directive_stack
123                        .pop()
124                        .expect("directive_stack has at least 2 elements");
125                    let block_info = block_stack.pop().expect("block_stack has matching entry");
126
127                    let completed_directive = match block_info.directive_type.as_str() {
128                        "handle_path" => {
129                            let pattern = block_info.args.first().cloned().unwrap_or_default();
130                            Directive::HandlePath {
131                                pattern,
132                                directives: finished_directives,
133                            }
134                        }
135                        "method" => Directive::Method {
136                            methods: block_info.args,
137                            directives: finished_directives,
138                        },
139                        "reverse_proxy" => {
140                            let to = block_info.args.first().cloned().unwrap_or_default();
141                            Directive::ReverseProxy {
142                                to,
143                                connect_timeout: block_info.connect_timeout,
144                                read_timeout: block_info.read_timeout,
145                                header_up: block_info.header_up,
146                            }
147                        }
148                        _ => {
149                            return Err(ProxyError::Parse(format!(
150                                "Unknown block type: {}",
151                                block_info.directive_type
152                            )))
153                        }
154                    };
155
156                    directive_stack
157                        .last_mut()
158                        .expect("directive_stack has parent after pop")
159                        .push(completed_directive);
160                } else {
161                    // Site block closed
162                    if let Some(address) = current_site_address.take() {
163                        let site_directives = directive_stack
164                            .pop()
165                            .expect("site directive_stack is non-empty");
166                        if sites.contains_key(&address) {
167                            return Err(ProxyError::Parse(format!(
168                                "Duplicate site address '{}'. \
169                                 Each address may appear only once in the configuration.",
170                                address
171                            )));
172                        }
173                        sites.insert(
174                            address.clone(),
175                            SiteConfig {
176                                address,
177                                directives: site_directives,
178                                tls: current_site_tls.take(),
179                            },
180                        );
181                        directive_stack.push(vec![]);
182                    }
183                }
184                continue;
185            }
186
187            // 3. Handle simple directives (single line)
188            let parts: Vec<&str> = line.split_whitespace().collect();
189            if parts.is_empty() {
190                continue;
191            }
192
193            let directive_name = parts[0];
194            let args = parts[1..].to_vec();
195
196            // Special handling: timeout and header_up settings inside a reverse_proxy block
197            if let Some(block) = block_stack.last_mut() {
198                if block.directive_type == "reverse_proxy" {
199                    match directive_name {
200                        "connect_timeout" => {
201                            let raw = args.first().cloned().ok_or_else(|| {
202                                ProxyError::Parse("Missing value for connect_timeout".to_string())
203                            })?;
204                            block.connect_timeout = Some(parse_duration(raw).map_err(|e| {
205                                ProxyError::Parse(format!(
206                                    "Invalid connect_timeout on line {}: {}",
207                                    line_num + 1,
208                                    e
209                                ))
210                            })?);
211                            continue;
212                        }
213                        "read_timeout" => {
214                            let raw = args.first().cloned().ok_or_else(|| {
215                                ProxyError::Parse("Missing value for read_timeout".to_string())
216                            })?;
217                            block.read_timeout = Some(parse_duration(raw).map_err(|e| {
218                                ProxyError::Parse(format!(
219                                    "Invalid read_timeout on line {}: {}",
220                                    line_num + 1,
221                                    e
222                                ))
223                            })?);
224                            continue;
225                        }
226                        "header_up" => {
227                            let raw_name = args.first().cloned().ok_or_else(|| {
228                                ProxyError::Parse(format!(
229                                    "Missing header name for header_up on line {}",
230                                    line_num + 1
231                                ))
232                            })?;
233                            let directive = if let Some(name) = raw_name.strip_prefix('-') {
234                                if name.is_empty() {
235                                    return Err(ProxyError::Parse(format!(
236                                        "Missing header name after '-' in header_up on line {}",
237                                        line_num + 1
238                                    )));
239                                }
240                                HeaderDirective {
241                                    name: name.to_string(),
242                                    value: None,
243                                }
244                            } else {
245                                let value = args[1..].join(" ");
246                                if value.is_empty() {
247                                    return Err(ProxyError::Parse(format!(
248                                        "Missing value for header_up {} on line {}",
249                                        raw_name,
250                                        line_num + 1
251                                    )));
252                                }
253                                HeaderDirective {
254                                    name: raw_name.to_string(),
255                                    value: Some(value),
256                                }
257                            };
258                            block.header_up.push(directive);
259                            continue;
260                        }
261                        _ => {
262                            return Err(ProxyError::Parse(format!(
263                                "Unexpected directive '{}' inside reverse_proxy block on line {}. Allowed: connect_timeout, read_timeout, header_up.",
264                                directive_name, line_num + 1
265                            )));
266                        }
267                    }
268                }
269            }
270
271            // Special handling: tls directive at site level
272            if directive_name == "tls" && block_stack.is_empty() {
273                let cert_path = args.first().cloned().ok_or_else(|| {
274                    ProxyError::Parse(format!(
275                        "Missing cert path for tls directive on line {}",
276                        line_num + 1
277                    ))
278                })?;
279                let key_path = args.get(1).cloned().ok_or_else(|| {
280                    ProxyError::Parse(format!(
281                        "Missing key path for tls directive on line {}",
282                        line_num + 1
283                    ))
284                })?;
285                if current_site_tls.is_some() {
286                    return Err(ProxyError::Parse(format!(
287                        "Duplicate tls directive on line {}. Only one tls per site is allowed.",
288                        line_num + 1
289                    )));
290                }
291                current_site_tls = Some(crate::config::TlsConfig {
292                    cert_path: cert_path.to_string(),
293                    key_path: key_path.to_string(),
294                });
295                continue;
296            }
297
298            // Regular directive parsing
299            let directive = match directive_name {
300                "reverse_proxy" => {
301                    let to = args.first().cloned().ok_or_else(|| {
302                        ProxyError::Parse("Missing backend URL for reverse_proxy".to_string())
303                    })?;
304                    Directive::ReverseProxy {
305                        to: to.to_string(),
306                        connect_timeout: None,
307                        read_timeout: None,
308                        header_up: vec![],
309                    }
310                }
311                "uri_replace" => {
312                    let find = args.first().cloned().ok_or_else(|| {
313                        ProxyError::Parse("Missing 'find' arg for uri_replace".to_string())
314                    })?;
315                    let replace = args.get(1).cloned().ok_or_else(|| {
316                        ProxyError::Parse("Missing 'replace' arg for uri_replace".to_string())
317                    })?;
318                    Directive::UriReplace {
319                        find: find.to_string(),
320                        replace: replace.to_string(),
321                    }
322                }
323                "header" => {
324                    let raw_name = args.first().cloned().ok_or_else(|| {
325                        ProxyError::Parse("Missing 'name' arg for header".to_string())
326                    })?;
327                    if let Some(name) = raw_name.strip_prefix('-') {
328                        if name.is_empty() {
329                            return Err(ProxyError::Parse(
330                                "Missing header name after '-' for header removal".to_string(),
331                            ));
332                        }
333                        Directive::Header {
334                            name: name.to_string(),
335                            value: None,
336                        }
337                    } else {
338                        let value = args.get(1).cloned().ok_or_else(|| {
339                            ProxyError::Parse("Missing 'value' arg for header".to_string())
340                        })?;
341                        Directive::Header {
342                            name: raw_name.to_string(),
343                            value: Some(value.to_string()),
344                        }
345                    }
346                }
347                "respond" => {
348                    let status = args.first().and_then(|s| s.parse().ok()).ok_or_else(|| {
349                        ProxyError::Parse("Invalid status for respond".to_string())
350                    })?;
351                    let body = args.get(1).cloned().unwrap_or_default();
352                    Directive::Respond {
353                        status,
354                        body: body.to_string(),
355                    }
356                }
357                "strip_prefix" => {
358                    let prefix = args.first().cloned().ok_or_else(|| {
359                        ProxyError::Parse("Missing 'prefix' arg for strip_prefix".to_string())
360                    })?;
361                    Directive::StripPrefix {
362                        prefix: prefix.to_string(),
363                    }
364                }
365                "redirect" => {
366                    let (status, url) = if args.len() >= 2 {
367                        let status: u16 = args[0].parse().map_err(|_| {
368                            ProxyError::Parse(format!(
369                                "Invalid status code for redirect: {}",
370                                args[0]
371                            ))
372                        })?;
373                        let url = args[1..].join(" ");
374                        (status, url)
375                    } else {
376                        let url = args.first().cloned().ok_or_else(|| {
377                            ProxyError::Parse("Missing 'url' arg for redirect".to_string())
378                        })?;
379                        (301u16, url.to_string())
380                    };
381                    Directive::Redirect {
382                        status,
383                        url: url.to_string(),
384                    }
385                }
386                _ => {
387                    return Err(ProxyError::Parse(format!(
388                        "Unknown directive '{}' on line {}",
389                        directive_name,
390                        line_num + 1
391                    )))
392                }
393            };
394
395            directive_stack
396                .last_mut()
397                .expect("directive_stack is non-empty")
398                .push(directive);
399        }
400
401        validate_listen_sockets(&sites)?;
402
403        Ok(Config { sites })
404    }
405}
406
407/// Validate TLS/plain consistency and unique SNI hostnames per listen socket.
408fn validate_listen_sockets(sites: &HashMap<String, SiteConfig>) -> Result<(), ProxyError> {
409    let mut socket_tls: HashMap<SocketAddr, bool> = HashMap::new();
410    let mut socket_sni: HashMap<SocketAddr, HashMap<String, String>> = HashMap::new();
411
412    for site in sites.values() {
413        let listen_addr =
414            resolve_listen_addr(&site.address).map_err(|e| ProxyError::Parse(e.to_string()))?;
415        let is_tls = site.tls.is_some();
416
417        if let Some(&prev_tls) = socket_tls.get(&listen_addr) {
418            if prev_tls != is_tls {
419                return Err(ProxyError::Parse(format!(
420                    "Mixed TLS and non-TLS sites on the same listen address {} is not supported. \
421                     Site '{}' is {} but conflicts with another site on this socket.",
422                    listen_addr,
423                    site.address,
424                    if is_tls { "TLS" } else { "plain HTTP" }
425                )));
426            }
427        } else {
428            socket_tls.insert(listen_addr, is_tls);
429        }
430
431        if is_tls {
432            let sni = extract_hostname(&site.address).to_ascii_lowercase();
433            let sni_map = socket_sni.entry(listen_addr).or_default();
434            if let Some(existing) = sni_map.get(&sni) {
435                return Err(ProxyError::Parse(format!(
436                    "Duplicate SNI hostname '{}' on listen address {} (sites '{}' and '{}')",
437                    sni, listen_addr, existing, site.address
438                )));
439            }
440            sni_map.insert(sni, site.address.clone());
441        }
442    }
443
444    Ok(())
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a configuration that must be valid.
    fn parse_ok(config: &str) -> Config {
        config.parse().unwrap()
    }

    /// Parse a configuration that must be rejected, returning the error text.
    fn parse_err(config: &str) -> String {
        let result: Result<Config, _> = config.parse();
        assert!(result.is_err());
        format!("{}", result.unwrap_err())
    }

    #[test]
    fn test_parse_duration_seconds() {
        for &(input, expected) in &[("30", 30u64), ("30s", 30u64)] {
            assert_eq!(parse_duration(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_parse_duration_minutes() {
        assert_eq!(300, parse_duration("5m").unwrap());
    }

    #[test]
    fn test_parse_duration_hours() {
        assert_eq!(7200, parse_duration("2h").unwrap());
    }

    #[test]
    fn test_parse_duration_days() {
        assert_eq!(86400, parse_duration("1d").unwrap());
    }

    #[test]
    fn test_parse_duration_invalid() {
        for bad in &["", "abc", "10x"] {
            assert!(parse_duration(bad).is_err());
        }
    }

    #[test]
    fn test_parse_reverse_proxy_simple() {
        let cfg = parse_ok("localhost:8080 {\n    reverse_proxy http://backend:9001\n}");
        let site = cfg.sites.get("localhost:8080").unwrap();
        assert_eq!(site.directives.len(), 1);

        if let Directive::ReverseProxy {
            to,
            connect_timeout,
            read_timeout,
            ..
        } = &site.directives[0]
        {
            assert_eq!(to, "http://backend:9001");
            assert!(connect_timeout.is_none());
            assert!(read_timeout.is_none());
        } else {
            panic!("Expected ReverseProxy directive");
        }
    }

    #[test]
    fn test_parse_reverse_proxy_with_timeouts() {
        let cfg = parse_ok(
            r#"localhost:8080 {
    reverse_proxy http://backend:9001 {
        connect_timeout 10s
        read_timeout 5m
    }
}"#,
        );
        let site = cfg.sites.get("localhost:8080").unwrap();
        assert_eq!(site.directives.len(), 1);

        if let Directive::ReverseProxy {
            to,
            connect_timeout,
            read_timeout,
            ..
        } = &site.directives[0]
        {
            assert_eq!(to, "http://backend:9001");
            assert_eq!(*connect_timeout, Some(10));
            assert_eq!(*read_timeout, Some(300));
        } else {
            panic!("Expected ReverseProxy directive");
        }
    }

    #[test]
    fn test_parse_reverse_proxy_with_connect_timeout_only() {
        let cfg = parse_ok(
            r#"localhost:8080 {
    reverse_proxy http://backend:9001 {
        connect_timeout 5s
    }
}"#,
        );
        let site = cfg.sites.get("localhost:8080").unwrap();

        if let Directive::ReverseProxy {
            connect_timeout,
            read_timeout,
            ..
        } = &site.directives[0]
        {
            assert_eq!(*connect_timeout, Some(5));
            assert!(read_timeout.is_none());
        } else {
            panic!("Expected ReverseProxy directive");
        }
    }

    #[test]
    fn test_parse_reverse_proxy_with_header_up() {
        let cfg = parse_ok(
            r#"localhost:8080 {
    reverse_proxy https://api.example.com:443 {
        header_up Host {upstream_host}
        header_up X-Original-Uri {request.uri}
        header_up -Accept-Encoding
    }
}"#,
        );
        let site = cfg.sites.get("localhost:8080").unwrap();

        if let Directive::ReverseProxy { to, header_up, .. } = &site.directives[0] {
            assert_eq!(to, "https://api.example.com:443");

            let expected: [(&str, Option<&str>); 3] = [
                ("Host", Some("{upstream_host}")),
                ("X-Original-Uri", Some("{request.uri}")),
                ("Accept-Encoding", None),
            ];
            assert_eq!(header_up.len(), expected.len());
            for (actual, (name, value)) in header_up.iter().zip(expected.iter()) {
                assert_eq!(actual.name, *name);
                assert_eq!(actual.value.as_deref(), *value);
            }
        } else {
            panic!("Expected ReverseProxy directive");
        }
    }

    #[test]
    fn test_parse_reverse_proxy_block_rejects_unknown_directive() {
        let err_msg = parse_err(
            r#"localhost:8080 {
    reverse_proxy http://backend:9001 {
        unknown_setting 42
    }
}"#,
        );
        assert!(err_msg.contains("Unexpected directive"), "{}", err_msg);
    }

    #[test]
    fn test_parse_tls_directive() {
        let cfg = parse_ok(
            r#"example.com:443 {
    tls /etc/ssl/cert.pem /etc/ssl/key.pem
    reverse_proxy backend:8080
}"#,
        );
        let site = cfg.sites.get("example.com:443").unwrap();

        let tls = site.tls.as_ref().unwrap();
        assert_eq!(tls.cert_path, "/etc/ssl/cert.pem");
        assert_eq!(tls.key_path, "/etc/ssl/key.pem");

        assert_eq!(site.directives.len(), 1);
        if let Directive::ReverseProxy { to, .. } = &site.directives[0] {
            assert_eq!(to, "backend:8080");
        } else {
            panic!("Expected ReverseProxy directive");
        }
    }

    #[test]
    fn test_parse_tls_missing_cert_path() {
        let err_msg = parse_err("example.com:443 {\n    tls\n}");
        assert!(err_msg.contains("Missing cert path"), "{}", err_msg);
    }

    #[test]
    fn test_parse_tls_missing_key_path() {
        let err_msg = parse_err("example.com:443 {\n    tls /etc/ssl/cert.pem\n}");
        assert!(err_msg.contains("Missing key path"), "{}", err_msg);
    }

    #[test]
    fn test_parse_tls_duplicate_rejected() {
        let err_msg = parse_err(
            r#"example.com:443 {
    tls /a/cert.pem /a/key.pem
    tls /b/cert.pem /b/key.pem
    reverse_proxy backend:8080
}"#,
        );
        assert!(err_msg.contains("Duplicate tls"), "{}", err_msg);
    }

    #[test]
    fn test_parse_mixed_tls_and_non_tls_sites() {
        let cfg = parse_ok(
            r#"localhost:8080 {
    reverse_proxy backend:3000
}
example.com:443 {
    tls /etc/ssl/cert.pem /etc/ssl/key.pem
    reverse_proxy backend:8080
}"#,
        );

        // Plain HTTP site.
        assert!(cfg.sites.get("localhost:8080").unwrap().tls.is_none());
        // HTTPS site.
        assert!(cfg.sites.get("example.com:443").unwrap().tls.is_some());
    }

    #[test]
    fn test_parse_duplicate_address_rejected() {
        let err_msg = parse_err(
            r#"example.com:443 {
    tls /a/cert.pem /a/key.pem
    reverse_proxy backend:8080
}
example.com:443 {
    reverse_proxy backend:9000
}"#,
        );
        assert!(
            err_msg.contains("Duplicate site address"),
            "Expected 'Duplicate site address' error, got: {}",
            err_msg
        );
    }

    #[test]
    fn test_parse_mixed_tls_on_same_listen_socket_rejected() {
        let err_msg = parse_err(
            r#"example.com:443 {
    tls /etc/ssl/cert.pem /etc/ssl/key.pem
    reverse_proxy backend:8080
}
0.0.0.0:443 {
    reverse_proxy backend:3000
}"#,
        );
        assert!(
            err_msg.contains("Mixed TLS and non-TLS"),
            "got: {}",
            err_msg
        );
    }

    #[test]
    fn test_parse_duplicate_sni_on_same_listen_socket_rejected() {
        let err_msg = parse_err(
            r#"Example.com:8443 {
    tls /a/cert.pem /a/key.pem
    respond 200 "A"
}
example.com:8443 {
    tls /b/cert.pem /b/key.pem
    respond 200 "B"
}"#,
        );
        assert!(
            err_msg.contains("Duplicate SNI hostname"),
            "got: {}",
            err_msg
        );
    }
}
718            "got: {}",
719            err_msg
720        );
721    }
722}