// warpdrive_proxy/config/toml.rs

//! TOML configuration file parsing.
//!
//! This module provides TOML-based configuration for advanced routing scenarios.
//! A TOML config enables multi-upstream routing, load balancing, and path-based routing.

use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;

use crate::router::Protocol;

12/// Root TOML configuration structure
13#[derive(Debug, Clone, Deserialize, Serialize)]
14pub struct TomlConfig {
15    /// Server configuration (optional, falls back to env vars)
16    #[serde(default)]
17    pub server: ServerConfig,
18
19    /// Upstream service definitions
20    pub upstreams: HashMap<String, UpstreamConfig>,
21
22    /// Routing rules (evaluated in order)
23    pub routes: Vec<RouteConfig>,
24}
25
26/// Server configuration section
27#[derive(Debug, Clone, Default, Deserialize, Serialize)]
28pub struct ServerConfig {
29    /// HTTP port (default: from env or 8080)
30    pub http_port: Option<u16>,
31
32    /// HTTPS port (default: from env or 443)
33    pub https_port: Option<u16>,
34
35    /// Number of worker threads (default: num_cpus)
36    pub worker_threads: Option<usize>,
37}
38
39/// Upstream service configuration
40#[derive(Debug, Clone, Deserialize, Serialize)]
41pub struct UpstreamConfig {
42    /// Protocol type (http, https, ws, wss, grpc)
43    #[serde(default)]
44    pub protocol: Protocol,
45
46    /// Single backend: hostname
47    pub host: Option<String>,
48
49    /// Single backend: port
50    pub port: Option<u16>,
51
52    /// Single backend: Unix domain socket path
53    pub socket: Option<PathBuf>,
54
55    /// Load balanced pool: multiple backend instances
56    pub backends: Option<Vec<BackendConfig>>,
57
58    /// Load balancing strategy ("round_robin", "random", "least_conn", "ip_hash")
59    #[serde(default = "default_strategy")]
60    pub strategy: String,
61
62    /// SNI hostname for TLS (defaults to host)
63    pub sni: Option<String>,
64
65    /// Process supervision (optional)
66    pub process: Option<ProcessConfig>,
67}
68
/// Serde default for [`UpstreamConfig::strategy`]: round-robin balancing.
fn default_strategy() -> String {
    String::from("round_robin")
}

73/// Backend instance configuration (for load balanced pools)
74#[derive(Debug, Clone, Deserialize, Serialize)]
75pub struct BackendConfig {
76    /// Protocol type (http, https, ws, wss, grpc)
77    #[serde(default)]
78    pub protocol: Protocol,
79
80    /// Hostname
81    pub host: Option<String>,
82
83    /// Port
84    pub port: Option<u16>,
85
86    /// Unix domain socket path
87    pub socket: Option<PathBuf>,
88
89    /// SNI hostname for TLS (defaults to host)
90    pub sni: Option<String>,
91}
92
93/// Route matching and upstream selection
94#[derive(Debug, Clone, Deserialize, Serialize)]
95pub struct RouteConfig {
96    /// Path prefix to match (e.g., "/api")
97    pub path_prefix: Option<String>,
98
99    /// Exact path to match
100    pub path_exact: Option<String>,
101
102    /// Regex pattern to match
103    pub path_regex: Option<String>,
104
105    /// Host header to match
106    pub host: Option<String>,
107
108    /// HTTP methods to match
109    pub methods: Option<Vec<String>>,
110
111    /// Header to match
112    pub header: Option<HeaderMatch>,
113
114    /// Target upstream name
115    pub upstream: String,
116
117    /// Strip the matched prefix before forwarding
118    #[serde(default)]
119    pub strip_prefix: bool,
120
121    /// Rewrite the path
122    pub rewrite: Option<String>,
123
124    /// Headers to add to request
125    #[serde(default)]
126    pub add_headers: Vec<HeaderValue>,
127
128    /// Read timeout in seconds
129    pub read_timeout_secs: Option<u64>,
130
131    /// Connect timeout in seconds
132    pub connect_timeout_secs: Option<u64>,
133
134    /// Description (for logging/debugging)
135    pub description: Option<String>,
136}
137
138/// Header matching configuration
139#[derive(Debug, Clone, Deserialize, Serialize)]
140pub struct HeaderMatch {
141    pub name: String,
142    pub value: String,
143}
144
145/// Header to add to request
146#[derive(Debug, Clone, Deserialize, Serialize)]
147pub struct HeaderValue {
148    pub name: String,
149    pub value: String,
150}
151
152/// Process supervision configuration
153#[derive(Debug, Clone, Deserialize, Serialize)]
154pub struct ProcessConfig {
155    /// Command to execute
156    pub command: String,
157
158    /// Command arguments
159    #[serde(default)]
160    pub args: Vec<String>,
161
162    /// Environment variables
163    #[serde(default)]
164    pub env: HashMap<String, String>,
165}
166
167impl TomlConfig {
168    /// Load configuration from a TOML file
169    pub fn from_file(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
170        let contents = std::fs::read_to_string(path)?;
171        let config: TomlConfig = toml::from_str(&contents)?;
172        Ok(config)
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_toml_parsing() {
        let toml_str = r#"
            [server]
            http_port = 8080
            https_port = 443

            [upstreams.rails]
            protocol = "http"
            host = "127.0.0.1"
            port = 3000

            [upstreams.cable]
            protocol = "wss"
            host = "127.0.0.1"
            port = 3001

            [[routes]]
            path_prefix = "/cable"
            upstream = "cable"

            [[routes]]
            path_prefix = "/"
            upstream = "rails"
        "#;

        let config: TomlConfig = toml::from_str(toml_str).unwrap();

        assert_eq!(config.server.http_port, Some(8080));
        assert_eq!(config.server.https_port, Some(443));
        assert_eq!(config.upstreams.len(), 2);
        assert_eq!(config.routes.len(), 2);

        let rails = config.upstreams.get("rails").unwrap();
        assert_eq!(rails.protocol, Protocol::Http);
        assert_eq!(rails.port, Some(3000));
    }

    #[test]
    fn test_load_balanced_upstream() {
        let toml_str = r#"
            [upstreams.rails]
            protocol = "http"
            strategy = "round_robin"

            [[upstreams.rails.backends]]
            host = "127.0.0.1"
            port = 3000

            [[upstreams.rails.backends]]
            host = "127.0.0.1"
            port = 3001

            [[routes]]
            path_prefix = "/"
            upstream = "rails"
        "#;

        let config: TomlConfig = toml::from_str(toml_str).unwrap();
        let rails = config.upstreams.get("rails").unwrap();

        assert_eq!(rails.strategy, "round_robin");
        assert!(rails.backends.is_some());
        assert_eq!(rails.backends.as_ref().unwrap().len(), 2);
    }

    /// Pins every serde default declared in this module: an omitted `[server]`
    /// table, `strategy`, `strip_prefix` and `add_headers`.
    #[test]
    fn test_defaults_applied_when_fields_omitted() {
        let toml_str = r#"
            [upstreams.app]
            host = "127.0.0.1"
            port = 4000

            [[routes]]
            upstream = "app"
        "#;

        let config: TomlConfig = toml::from_str(toml_str).unwrap();

        // No `[server]` table: every server option stays unset.
        assert_eq!(config.server.http_port, None);
        assert_eq!(config.server.https_port, None);
        assert_eq!(config.server.worker_threads, None);

        let app = config.upstreams.get("app").unwrap();
        assert_eq!(app.strategy, "round_robin");
        assert!(app.backends.is_none());
        assert!(app.process.is_none());

        assert_eq!(config.routes.len(), 1);
        let route = &config.routes[0];
        assert_eq!(route.upstream, "app");
        assert!(!route.strip_prefix);
        assert!(route.add_headers.is_empty());
        assert!(route.path_prefix.is_none());
    }

    /// A Unix-socket upstream with a supervised process; `args`/`env` default to empty.
    #[test]
    fn test_socket_upstream_with_process() {
        let toml_str = r#"
            [upstreams.app]
            socket = "/tmp/app.sock"

            [upstreams.app.process]
            command = "bundle"

            [[routes]]
            path_prefix = "/"
            upstream = "app"
        "#;

        let config: TomlConfig = toml::from_str(toml_str).unwrap();
        let app = config.upstreams.get("app").unwrap();

        assert_eq!(
            app.socket.as_deref(),
            Some(std::path::Path::new("/tmp/app.sock"))
        );
        assert!(app.host.is_none());

        let process = app.process.as_ref().unwrap();
        assert_eq!(process.command, "bundle");
        assert!(process.args.is_empty());
        assert!(process.env.is_empty());
    }

    /// `RouteConfig::upstream` has no default, so a route without it must fail to parse.
    #[test]
    fn test_route_without_upstream_is_rejected() {
        let toml_str = r#"
            [upstreams.app]
            host = "127.0.0.1"
            port = 4000

            [[routes]]
            path_prefix = "/"
        "#;

        assert!(toml::from_str::<TomlConfig>(toml_str).is_err());
    }
}
242        assert_eq!(rails.backends.as_ref().unwrap().len(), 2);
243    }
244}