// warpdrive_proxy/config/toml.rs
use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;

use crate::router::Protocol;
11
12#[derive(Debug, Clone, Deserialize, Serialize)]
14pub struct TomlConfig {
15 #[serde(default)]
17 pub server: ServerConfig,
18
19 pub upstreams: HashMap<String, UpstreamConfig>,
21
22 pub routes: Vec<RouteConfig>,
24}
25
26#[derive(Debug, Clone, Default, Deserialize, Serialize)]
28pub struct ServerConfig {
29 pub http_port: Option<u16>,
31
32 pub https_port: Option<u16>,
34
35 pub worker_threads: Option<usize>,
37}
38
39#[derive(Debug, Clone, Deserialize, Serialize)]
41pub struct UpstreamConfig {
42 #[serde(default)]
44 pub protocol: Protocol,
45
46 pub host: Option<String>,
48
49 pub port: Option<u16>,
51
52 pub socket: Option<PathBuf>,
54
55 pub backends: Option<Vec<BackendConfig>>,
57
58 #[serde(default = "default_strategy")]
60 pub strategy: String,
61
62 pub sni: Option<String>,
64
65 pub process: Option<ProcessConfig>,
67}
68
/// Serde default for `UpstreamConfig::strategy`, used when an upstream table
/// has no `strategy` key.
fn default_strategy() -> String {
    String::from("round_robin")
}
72
73#[derive(Debug, Clone, Deserialize, Serialize)]
75pub struct BackendConfig {
76 #[serde(default)]
78 pub protocol: Protocol,
79
80 pub host: Option<String>,
82
83 pub port: Option<u16>,
85
86 pub socket: Option<PathBuf>,
88
89 pub sni: Option<String>,
91}
92
93#[derive(Debug, Clone, Deserialize, Serialize)]
95pub struct RouteConfig {
96 pub path_prefix: Option<String>,
98
99 pub path_exact: Option<String>,
101
102 pub path_regex: Option<String>,
104
105 pub host: Option<String>,
107
108 pub methods: Option<Vec<String>>,
110
111 pub header: Option<HeaderMatch>,
113
114 pub upstream: String,
116
117 #[serde(default)]
119 pub strip_prefix: bool,
120
121 pub rewrite: Option<String>,
123
124 #[serde(default)]
126 pub add_headers: Vec<HeaderValue>,
127
128 pub read_timeout_secs: Option<u64>,
130
131 pub connect_timeout_secs: Option<u64>,
133
134 pub description: Option<String>,
136}
137
138#[derive(Debug, Clone, Deserialize, Serialize)]
140pub struct HeaderMatch {
141 pub name: String,
142 pub value: String,
143}
144
145#[derive(Debug, Clone, Deserialize, Serialize)]
147pub struct HeaderValue {
148 pub name: String,
149 pub value: String,
150}
151
152#[derive(Debug, Clone, Deserialize, Serialize)]
154pub struct ProcessConfig {
155 pub command: String,
157
158 #[serde(default)]
160 pub args: Vec<String>,
161
162 #[serde(default)]
164 pub env: HashMap<String, String>,
165}
166
167impl TomlConfig {
168 pub fn from_file(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
170 let contents = std::fs::read_to_string(path)?;
171 let config: TomlConfig = toml::from_str(&contents)?;
172 Ok(config)
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_toml_parsing() {
        let toml_str = r#"
            [server]
            http_port = 8080
            https_port = 443

            [upstreams.rails]
            protocol = "http"
            host = "127.0.0.1"
            port = 3000

            [upstreams.cable]
            protocol = "wss"
            host = "127.0.0.1"
            port = 3001

            [[routes]]
            path_prefix = "/cable"
            upstream = "cable"

            [[routes]]
            path_prefix = "/"
            upstream = "rails"
        "#;

        let config: TomlConfig = toml::from_str(toml_str).unwrap();

        assert_eq!(config.server.http_port, Some(8080));
        assert_eq!(config.server.https_port, Some(443));
        assert_eq!(config.upstreams.len(), 2);
        assert_eq!(config.routes.len(), 2);

        let rails = config.upstreams.get("rails").unwrap();
        assert_eq!(rails.protocol, Protocol::Http);
        assert_eq!(rails.port, Some(3000));

        // `[[routes]]` is a TOML array of tables, so file order is preserved.
        assert_eq!(config.routes[0].upstream, "cable");
        assert_eq!(config.routes[1].upstream, "rails");
    }

    #[test]
    fn test_load_balanced_upstream() {
        let toml_str = r#"
            [upstreams.rails]
            protocol = "http"
            strategy = "round_robin"

            [[upstreams.rails.backends]]
            host = "127.0.0.1"
            port = 3000

            [[upstreams.rails.backends]]
            host = "127.0.0.1"
            port = 3001

            [[routes]]
            path_prefix = "/"
            upstream = "rails"
        "#;

        let config: TomlConfig = toml::from_str(toml_str).unwrap();
        let rails = config.upstreams.get("rails").unwrap();

        assert_eq!(rails.strategy, "round_robin");
        assert!(rails.backends.is_some());
        assert_eq!(rails.backends.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_defaults_applied_when_keys_omitted() {
        // No `[server]`, no `protocol`, no `strategy`, no route flags: every
        // serde default in the schema is exercised here.
        let toml_str = r#"
            [[upstreams.app.backends]]
            host = "127.0.0.1"
            port = 3001

            [[routes]]
            upstream = "app"
        "#;

        let config: TomlConfig = toml::from_str(toml_str).unwrap();

        assert_eq!(config.server.http_port, None);
        assert_eq!(config.server.https_port, None);
        assert_eq!(config.server.worker_threads, None);

        let app = &config.upstreams["app"];
        assert_eq!(app.protocol, Protocol::default());
        assert_eq!(app.strategy, "round_robin");
        assert!(app.process.is_none());

        let backends = app.backends.as_deref().unwrap();
        assert_eq!(backends.len(), 1);
        assert_eq!(backends[0].protocol, Protocol::default());

        let route = &config.routes[0];
        assert!(!route.strip_prefix);
        assert!(route.add_headers.is_empty());
        assert!(route.path_prefix.is_none());
        assert!(route.read_timeout_secs.is_none());
    }

    #[test]
    fn test_process_defaults() {
        let toml_str = r#"
            [upstreams.app]
            host = "127.0.0.1"
            port = 3000

            [upstreams.app.process]
            command = "bin/rails"

            [[routes]]
            upstream = "app"
        "#;

        let config: TomlConfig = toml::from_str(toml_str).unwrap();
        let process = config.upstreams["app"].process.as_ref().unwrap();

        assert_eq!(process.command, "bin/rails");
        assert!(process.args.is_empty());
        assert!(process.env.is_empty());
    }

    #[test]
    fn test_route_requires_upstream() {
        let toml_str = r#"
            [upstreams.app]
            host = "127.0.0.1"

            [[routes]]
            path_prefix = "/"
        "#;

        assert!(toml::from_str::<TomlConfig>(toml_str).is_err());
    }

    #[test]
    fn test_from_file_round_trip() {
        let path = std::env::temp_dir().join(format!(
            "warpdrive-toml-config-{}.toml",
            std::process::id()
        ));
        std::fs::write(
            &path,
            "[upstreams.app]\nport = 3000\n\n[[routes]]\nupstream = \"app\"\n",
        )
        .unwrap();

        // Clean up before asserting so a failure does not leave the file behind.
        let result = TomlConfig::from_file(&path);
        std::fs::remove_file(&path).unwrap();

        let config = result.unwrap();
        assert_eq!(config.upstreams["app"].port, Some(3000));
        assert_eq!(config.routes.len(), 1);
    }

    #[test]
    fn test_from_file_missing_file_is_error() {
        let path = std::env::temp_dir().join(format!(
            "warpdrive-toml-config-missing-{}.toml",
            std::process::id()
        ));
        assert!(TomlConfig::from_file(&path).is_err());
    }
}