1use core::fmt;
19use core::ops::Deref;
20use core::str::FromStr;
21
22use std::collections::HashMap;
23use std::io::ErrorKind;
24use std::net::{IpAddr, Ipv4Addr, SocketAddr};
25use std::path::Path;
26
27use base64::engine::Engine as _;
28use base64::prelude::BASE64_STANDARD_NO_PAD;
29use http::Uri;
30use serde::{de, de::Deserializer, de::Visitor, Deserialize, Serialize};
31use tracing::{instrument, trace};
32use unicase::UniCase;
33
/// Origins allowed by default for CORS. The list is empty, so `AllowedOrigins::default()` is empty.
const CORS_ALLOWED_ORIGINS: &[&str] = &[];
/// HTTP methods allowed by default for CORS (used by `AllowedMethods::default()`).
const CORS_ALLOWED_METHODS: &[&str] = &["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"];
/// Request headers allowed by default for CORS (used by `AllowedHeaders::default()`).
const CORS_ALLOWED_HEADERS: &[&str] = &[
    "accept",
    "accept-language",
    "content-type",
    "content-language",
];
/// Response headers exposed by default for CORS. The list is empty, so `ExposedHeaders::default()` is empty.
const CORS_EXPOSED_HEADERS: &[&str] = &[];
/// Default value of the `cors_max_age_secs` / `max_age_secs` settings, in seconds.
const CORS_DEFAULT_MAX_AGE_SECS: u64 = 300;
44
/// Returns the listen address used when nothing else is configured.
///
/// The address is all IPv4 interfaces (`0.0.0.0`) on port 8000.
pub fn default_listen_address() -> SocketAddr {
    let any_interface = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
    SocketAddr::new(any_interface, 8000)
}
48
/// Runtime configuration for the HTTP server.
///
/// Instances are built by `load_settings` (from a key/value map or a JSON document) or by
/// `ServiceSettings::default`. The deprecated `tls` and `cors` sections are only read by
/// `ServiceSettings::from_json`, which copies their values into the top-level fields when those
/// are unset.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ServiceSettings {
    /// Socket address to listen on. It falls back to `default_listen_address()` when missing
    /// from serialized input.
    #[serde(default = "default_listen_address")]
    pub address: SocketAddr,
    /// Optional `Cache-Control` header value. `validate` checks that it is a valid header value.
    #[serde(default)]
    pub cache_control: Option<String>,
    /// Read-only mode flag. `Default` sets it to `Some(false)`; its effect is implemented elsewhere.
    #[serde(default)]
    pub readonly_mode: Option<bool>,
    /// CORS origins allowed to access the server.
    pub cors_allowed_origins: Option<AllowedOrigins>,
    /// CORS request headers the server accepts.
    pub cors_allowed_headers: Option<AllowedHeaders>,
    /// CORS methods the server accepts. `validate` checks each entry parses as an `http::Method`.
    pub cors_allowed_methods: Option<AllowedMethods>,
    /// Response headers exposed to CORS clients.
    pub cors_exposed_headers: Option<ExposedHeaders>,
    /// CORS max-age in seconds. `Default` uses `CORS_DEFAULT_MAX_AGE_SECS`.
    pub cors_max_age_secs: Option<u64>,
    /// Path to the TLS certificate. It must be set together with `tls_priv_key_file`, and
    /// `validate` checks that the file exists.
    #[serde(default)]
    pub tls_cert_file: Option<String>,
    /// Path to the TLS private key. It must be set together with `tls_cert_file`, and
    /// `validate` checks that the file exists.
    #[serde(default)]
    pub tls_priv_key_file: Option<String>,
    /// Optional timeout in milliseconds (per the `_ms` suffix); how it is applied is not visible here.
    #[serde(default)]
    pub timeout_ms: Option<u64>,
    /// Deprecated nested TLS section. It only supplies fallbacks for `tls_cert_file` and
    /// `tls_priv_key_file` during `from_json`.
    #[deprecated(since = "0.22.0", note = "Use top-level fields instead")]
    #[serde(default)]
    pub tls: Tls,
    /// Deprecated nested CORS section. It only supplies fallbacks for the `cors_*` fields
    /// during `from_json`.
    #[deprecated(since = "0.22.0", note = "Use top-level fields instead")]
    #[serde(default)]
    pub cors: Cors,
    /// Optional keep-alive override; interpreted elsewhere.
    #[serde(default)]
    pub disable_keepalive: Option<bool>,
}
87
88impl Default for ServiceSettings {
89 fn default() -> ServiceSettings {
90 #[allow(deprecated)]
91 ServiceSettings {
92 address: default_listen_address(),
93 cors_allowed_origins: Some(AllowedOrigins::default()),
94 cors_allowed_headers: Some(AllowedHeaders::default()),
95 cors_allowed_methods: Some(AllowedMethods::default()),
96 cors_exposed_headers: Some(ExposedHeaders::default()),
97 cors_max_age_secs: Some(CORS_DEFAULT_MAX_AGE_SECS),
98 tls_cert_file: None,
99 tls_priv_key_file: None,
100 timeout_ms: None,
101 cache_control: None,
102 readonly_mode: Some(false),
103 tls: Tls::default(),
104 cors: Cors::default(),
105 disable_keepalive: None,
106 }
107 }
108}
109
110impl ServiceSettings {
111 fn from_json(data: &str) -> Result<Self, HttpServerError> {
113 #[allow(deprecated)]
114 serde_json::from_str(data)
115 .map(|s: ServiceSettings| ServiceSettings {
118 address: s.address,
119 cache_control: s.cache_control,
120 readonly_mode: s.readonly_mode,
121 timeout_ms: s.timeout_ms,
122 tls_cert_file: s.tls_cert_file.or(s.tls.cert_file),
123 tls_priv_key_file: s.tls_priv_key_file.or(s.tls.priv_key_file),
124 cors_allowed_origins: s.cors_allowed_origins.or(s.cors.allowed_origins),
125 cors_allowed_headers: s.cors_allowed_headers.or(s.cors.allowed_headers),
126 cors_allowed_methods: s.cors_allowed_methods.or(s.cors.allowed_methods),
127 cors_exposed_headers: s.cors_exposed_headers.or(s.cors.exposed_headers),
128 cors_max_age_secs: s.cors_max_age_secs.or(s.cors.max_age_secs),
129 tls: Tls::default(),
130 cors: Cors::default(),
131 disable_keepalive: s.disable_keepalive,
132 })
133 .map_err(|e| HttpServerError::Settings(format!("invalid json: {e}")))
134 }
135
136 fn validate(&self) -> Result<(), HttpServerError> {
140 let mut errors = Vec::new();
141 match (&self.tls_cert_file, &self.tls_priv_key_file) {
143 (None, None) => {}
144 (Some(_), None) | (None, Some(_)) => {
145 errors.push(
146 "for tls, both 'tls_cert_file' and 'tls_priv_key_file' must be set".to_string(),
147 );
148 }
149 (Some(cert_file), Some(key_file)) => {
150 for f in &[("cert_file", &cert_file), ("priv_key_file", &key_file)] {
151 let path: &Path = f.1.as_ref();
152 if !path.is_file() {
153 errors.push(format!(
154 "missing tls_{} '{}'{}",
155 f.0,
156 &path.display(),
157 if path.is_absolute() {
158 ""
159 } else {
160 " : perhaps you should make the path absolute"
161 }
162 ));
163 }
164 }
165 }
166 }
167 if let Some(ref methods) = self.cors_allowed_methods {
168 for m in &methods.0 {
169 if http::Method::try_from(m.as_str()).is_err() {
170 errors.push(format!("invalid CORS method: '{m}'"));
171 }
172 }
173 }
174 if let Some(cache_control) = self.cache_control.as_ref() {
175 if http::HeaderValue::from_str(cache_control).is_err() {
176 errors.push(format!("Invalid Cache Control header : '{cache_control}'"));
177 }
178 }
179 if !errors.is_empty() {
180 Err(HttpServerError::Settings(format!(
181 "\nInvalid httpserver settings: \n{}\n",
182 errors.join("\n")
183 )))
184 } else {
185 Ok(())
186 }
187 }
188}
189
/// Errors produced while loading or validating [`ServiceSettings`].
#[derive(Debug, thiserror::Error)]
pub enum HttpServerError {
    /// A single configuration value could not be parsed (for example `port`, `address` or
    /// `cors_max_age_secs`).
    #[error("invalid parameter: {0}")]
    InvalidParameter(String),

    /// The settings as a whole are unusable. Causes include malformed base64 or JSON, a CORS
    /// list that is not a JSON array of strings, or a failed validation.
    #[error("problem reading settings: {0}")]
    Settings(String),
}
199
200#[instrument]
209pub fn load_settings(
210 default_address: Option<SocketAddr>,
211 values: &HashMap<String, String>,
212) -> Result<ServiceSettings, HttpServerError> {
213 trace!("load settings");
214 let values: HashMap<UniCase<&str>, &String> = values
217 .iter()
218 .map(|(k, v)| (UniCase::new(k.as_str()), v))
219 .collect();
220
221 if let Some(str) = values.get(&UniCase::new("config_b64")) {
222 let bytes = BASE64_STANDARD_NO_PAD
223 .decode(str)
224 .map_err(|e| HttpServerError::Settings(format!("invalid base64 encoding: {e}")))?;
225 return ServiceSettings::from_json(&String::from_utf8_lossy(&bytes));
226 }
227
228 if let Some(str) = values.get(&UniCase::new("config_json")) {
229 return ServiceSettings::from_json(str);
230 }
231
232 let mut settings = ServiceSettings::default();
233
234 if let Some(addr) = values.get(&UniCase::new("port")) {
236 let port = addr
237 .parse::<u16>()
238 .map_err(|_| HttpServerError::InvalidParameter(format!("Invalid port: {addr}")))?;
239 settings.address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port);
240 }
241 settings.address = values
243 .get(&UniCase::new("address"))
244 .map(|addr| {
245 SocketAddr::from_str(addr)
246 .map_err(|_| HttpServerError::InvalidParameter(format!("invalid address: {addr}")))
247 })
248 .transpose()?
249 .or(default_address)
250 .unwrap_or_else(default_listen_address);
251
252 if let Some(cache_control) = values.get(&UniCase::new("cache_control")) {
254 settings.cache_control = Some(cache_control.to_string());
255 }
256 if let Some(readonly_mode) = values.get(&UniCase::new("readonly_mode")) {
258 settings.readonly_mode = Some(readonly_mode.to_string().parse().unwrap_or(false));
259 }
260 if let Some(Ok(timeout_ms)) = values.get(&UniCase::new("timeout_ms")).map(|s| s.parse()) {
262 settings.timeout_ms = Some(timeout_ms)
263 }
264
265 if let Some(tls_cert_file) = values.get(&UniCase::new("tls_cert_file")) {
267 settings.tls_cert_file = Some(tls_cert_file.to_string());
268 }
269 if let Some(tls_priv_key_file) = values.get(&UniCase::new("tls_priv_key_file")) {
270 settings.tls_priv_key_file = Some(tls_priv_key_file.to_string());
271 }
272
273 if let Some(cors_allowed_origins) = values.get(&UniCase::new("cors_allowed_origins")) {
275 let origins: Vec<CorsOrigin> = serde_json::from_str(cors_allowed_origins)
276 .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_origins: {e}")))?;
277 settings.cors_allowed_origins = Some(AllowedOrigins(origins));
278 }
279 if let Some(cors_allowed_headers) = values.get(&UniCase::new("cors_allowed_headers")) {
280 let headers: Vec<String> = serde_json::from_str(cors_allowed_headers)
281 .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_headers: {e}")))?;
282 settings.cors_allowed_headers = Some(AllowedHeaders(headers));
283 }
284 if let Some(cors_allowed_methods) = values.get(&UniCase::new("cors_allowed_methods")) {
285 let methods: Vec<String> = serde_json::from_str(cors_allowed_methods)
286 .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_methods: {e}")))?;
287 settings.cors_allowed_methods = Some(AllowedMethods(methods));
288 }
289 if let Some(cors_exposed_headers) = values.get(&UniCase::new("cors_exposed_headers")) {
290 let headers: Vec<String> = serde_json::from_str(cors_exposed_headers)
291 .map_err(|e| HttpServerError::Settings(format!("invalid cors_exposed_headers: {e}")))?;
292 settings.cors_exposed_headers = Some(ExposedHeaders(headers));
293 }
294 if let Some(cors_max_age_secs) = values.get(&UniCase::new("cors_max_age_secs")) {
295 let max_age_secs: u64 = cors_max_age_secs.parse().map_err(|_| {
296 HttpServerError::InvalidParameter("Invalid cors_max_age_secs".to_string())
297 })?;
298 settings.cors_max_age_secs = Some(max_age_secs);
299 }
300 if let Some(disable_keepalive) = values.get(&UniCase::new("disable_keepalive")) {
301 settings.disable_keepalive = Some(disable_keepalive.parse().unwrap_or(false));
302 }
303
304 settings.validate()?;
305 Ok(settings)
306}
307
/// Deprecated nested TLS section of [`ServiceSettings`].
///
/// `ServiceSettings::from_json` uses these values as fallbacks for the top-level
/// `tls_cert_file` / `tls_priv_key_file` fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct Tls {
    /// Path to the TLS certificate file.
    pub cert_file: Option<String>,
    /// Path to the TLS private key file.
    pub priv_key_file: Option<String>,
}
314
315#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
316pub struct Cors {
317 pub allowed_origins: Option<AllowedOrigins>,
318 pub allowed_headers: Option<AllowedHeaders>,
319 pub allowed_methods: Option<AllowedMethods>,
320 pub exposed_headers: Option<ExposedHeaders>,
321 pub max_age_secs: Option<u64>,
322}
323
324impl Default for Cors {
325 fn default() -> Self {
326 Cors {
327 allowed_origins: Some(AllowedOrigins::default()),
328 allowed_headers: Some(AllowedHeaders::default()),
329 allowed_methods: Some(AllowedMethods::default()),
330 exposed_headers: Some(ExposedHeaders::default()),
331 max_age_secs: Some(CORS_DEFAULT_MAX_AGE_SECS),
332 }
333 }
334}
335
/// A CORS origin such as `https://example.com:3000`.
///
/// Values produced by its `FromStr` / `Deserialize` impls are validated (http/https scheme, no
/// path) and stored without a trailing `/`. `Default` and `AllowedOrigins::default` construct it
/// directly without validation.
#[derive(Debug, Clone, Default, Serialize, PartialEq, Eq)]
pub struct CorsOrigin(String);

/// List of origins allowed for CORS; dereferences to `Vec<CorsOrigin>`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedOrigins(Vec<CorsOrigin>);

/// List of request header names allowed for CORS; dereferences to `Vec<String>`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedHeaders(Vec<String>);

/// List of HTTP method names allowed for CORS; dereferences to `Vec<String>`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedMethods(Vec<String>);

/// List of response header names exposed to CORS clients; dereferences to `Vec<String>`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ExposedHeaders(Vec<String>);
350
351impl<'de> Deserialize<'de> for CorsOrigin {
352 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
353 where
354 D: Deserializer<'de>,
355 {
356 struct CorsOriginVisitor;
357 impl Visitor<'_> for CorsOriginVisitor {
358 type Value = CorsOrigin;
359
360 fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
361 write!(fmt, "an origin in format http[s]://example.com[:3000]",)
362 }
363
364 fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
365 where
366 E: de::Error,
367 {
368 CorsOrigin::from_str(v).map_err(E::custom)
369 }
370 }
371 deserializer.deserialize_str(CorsOriginVisitor)
372 }
373}
374
375impl FromStr for CorsOrigin {
376 type Err = std::io::Error;
377
378 fn from_str(origin: &str) -> Result<Self, Self::Err> {
379 let uri = Uri::from_str(origin).map_err(|invalid_uri| {
380 std::io::Error::new(
381 ErrorKind::InvalidInput,
382 format!("Invalid uri: {origin}.\n{invalid_uri}"),
383 )
384 })?;
385 if let Some(s) = uri.scheme_str() {
386 if s != "http" && s != "https" {
387 return Err(std::io::Error::new(
388 ErrorKind::InvalidInput,
389 format!(
390 "Cors origin invalid schema {}, only [http] and [https] are supported: ",
391 uri.scheme_str().unwrap()
392 ),
393 ));
394 }
395 } else {
396 return Err(std::io::Error::new(
397 ErrorKind::InvalidInput,
398 "Cors origin missing schema, only [http] or [https] are supported",
399 ));
400 }
401
402 if let Some(p) = uri.path_and_query() {
403 if p.as_str() != "/" {
404 return Err(std::io::Error::new(
405 ErrorKind::InvalidInput,
406 format!("Invalid value {} in cors schema.", p.as_str()),
407 ));
408 }
409 }
410 Ok(CorsOrigin(origin.trim_end_matches('/').to_owned()))
411 }
412}
413
414impl AsRef<str> for CorsOrigin {
415 fn as_ref(&self) -> &str {
416 &self.0
417 }
418}
419
420impl Deref for AllowedOrigins {
421 type Target = Vec<CorsOrigin>;
422
423 fn deref(&self) -> &Self::Target {
424 &self.0
425 }
426}
427
428impl Default for AllowedOrigins {
429 fn default() -> Self {
430 AllowedOrigins(
431 CORS_ALLOWED_ORIGINS
432 .iter()
433 .map(|s| CorsOrigin((*s).to_string()))
434 .collect::<Vec<_>>(),
435 )
436 }
437}
438
439impl Deref for AllowedHeaders {
440 type Target = Vec<String>;
441
442 fn deref(&self) -> &Self::Target {
443 &self.0
444 }
445}
446
447impl Default for AllowedHeaders {
448 fn default() -> Self {
449 AllowedHeaders(from_defaults(CORS_ALLOWED_HEADERS))
450 }
451}
452
453impl Default for AllowedMethods {
454 fn default() -> Self {
455 AllowedMethods(from_defaults(CORS_ALLOWED_METHODS))
456 }
457}
458
459impl Deref for AllowedMethods {
460 type Target = Vec<String>;
461
462 fn deref(&self) -> &Self::Target {
463 &self.0
464 }
465}
466
467impl Deref for ExposedHeaders {
468 type Target = Vec<String>;
469
470 fn deref(&self) -> &Self::Target {
471 &self.0
472 }
473}
474
475impl Default for ExposedHeaders {
476 fn default() -> Self {
477 ExposedHeaders(
478 CORS_EXPOSED_HEADERS
479 .iter()
480 .map(|s| (*s).to_string())
481 .collect::<Vec<_>>(),
482 )
483 }
484}
485
/// HTTP request methods.
///
/// Variants serialize and deserialize in upper case (`GET`, `POST`, ...). `FromStr` parses
/// them case-insensitively.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
pub enum HttpMethod {
    Get,
    Post,
    Put,
    Delete,
    Head,
    Options,
    Connect,
    Patch,
    Trace,
}
499
500impl FromStr for HttpMethod {
501 type Err = std::io::Error;
502
503 fn from_str(s: &str) -> Result<Self, Self::Err> {
504 match s.to_uppercase().as_str() {
505 "GET" => Ok(Self::Get),
506 "PUT" => Ok(Self::Put),
507 "POST" => Ok(Self::Post),
508 "DELETE" => Ok(Self::Delete),
509 "HEAD" => Ok(Self::Head),
510 "OPTIONS" => Ok(Self::Options),
511 "CONNECT" => Ok(Self::Connect),
512 "PATCH" => Ok(Self::Patch),
513 "TRACE" => Ok(Self::Trace),
514 _ => Err(std::io::Error::new(
515 std::io::ErrorKind::InvalidData,
516 format!("{s} is not a valid http method"),
517 )),
518 }
519 }
520}
521
/// Converts a slice of string slices into a `Vec` of any type constructible from `&str`.
///
/// Used to turn the `CORS_*` constant lists into owned defaults.
fn from_defaults<'d, T>(d: &[&'d str]) -> Vec<T>
where
    T: std::convert::From<&'d str>,
{
    let mut converted = Vec::with_capacity(d.len());
    for &item in d {
        converted.push(T::from(item));
    }
    converted
}
530
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use crate::settings::{CorsOrigin, ServiceSettings};

    /// Origins that must be accepted and stored unchanged.
    const GOOD_ORIGINS: &[&str] = &[
        "https://www.example.com",
        "https://www.example.com:1000",
        "http://localhost",
        "http://localhost:8080",
        "http://127.0.0.1",
        "http://127.0.0.1:8080",
        "https://:8080",
    ];

    /// Origins that must be rejected (bad or missing scheme, paths, empty authority).
    const BAD_ORIGINS: &[&str] = &[
        "ftp://www.example.com",
        "localhost",
        "127.0.0.1",
        "127.0.0.1:8080",
        ":8080",
        "/path/file.txt",
        "http:",
        "https://",
    ];

    #[test]
    fn settings_init() {
        let s = ServiceSettings::default();
        assert!(s.address.is_ipv4());
        assert!(s.cors_allowed_methods.is_some());
        assert!(s.cors_allowed_origins.is_some());
        assert!(s.cors_allowed_origins.unwrap().0.is_empty());
    }

    #[test]
    fn settings_json() {
        let json = r#"{
            "cors": {
                "allowed_headers": [ "X-Cookies" ]
            }
        }"#;

        let s = ServiceSettings::from_json(json).expect("parse_json");
        assert_eq!(s.cors_allowed_headers.as_ref().unwrap().0.len(), 1);
        assert_eq!(
            s.cors_allowed_headers.as_ref().unwrap().0.first().unwrap(),
            "X-Cookies"
        );
    }

    #[test]
    fn origins_deserialize() {
        for valid in GOOD_ORIGINS {
            let o = serde_json::from_value::<CorsOrigin>(serde_json::Value::String(
                (*valid).to_string(),
            ));
            assert!(o.is_ok(), "from_value '{valid}'");

            assert_eq!(&o.unwrap().0, valid);
        }
    }

    #[test]
    fn origins_from_str() {
        for &valid in GOOD_ORIGINS {
            let o = CorsOrigin::from_str(valid);
            println!("{valid}: {o:?}");
            assert!(o.is_ok(), "from_str '{valid}'");

            assert_eq!(&o.unwrap().0, valid);
        }
    }

    #[test]
    fn origins_trailing_slash_is_trimmed() {
        let o = CorsOrigin::from_str("https://www.example.com/")
            .expect("a single trailing slash is allowed");
        assert_eq!(o.0, "https://www.example.com");
    }

    #[test]
    fn origins_negative() {
        for &bad in BAD_ORIGINS {
            let o = serde_json::from_value::<CorsOrigin>(serde_json::Value::String(bad.to_string()));
            println!("{bad}: {o:?}");
            assert!(o.is_err(), "from_value '{bad}' (expect err)");

            // Parse the raw string directly. `serde_json::from_str` on unquoted text fails as
            // invalid JSON before the origin check runs, so it would not test anything.
            let o = CorsOrigin::from_str(bad);
            println!("{bad}: {o:?}");
            assert!(o.is_err(), "from_str '{bad}' (expect err)");
        }
    }
}