// wasmcloud_provider_http_server/settings.rs

//! Configuration settings for [`HttpServerProvider`](crate::HttpServerProvider).
//!
//! The "values" map in the component link definition may contain one or more of the
//! following keys, which determine how the configuration is parsed:
//!
//! - `config_b64`: the configuration is a base64-encoded JSON string.
//! - `config_json`: the configuration is a raw JSON string.
//!
//! If neither key is present, individual settings (`address`, `port`, `cache_control`,
//! `tls_cert_file`, `cors_allowed_origins`, ...) are read directly from the map.
//!
//! If no configuration is provided, the following defaults are used:
//! - TLS is disabled.
//! - CORS allows all hosts (origins), most methods, and common headers
//!   (see the constants below).
//! - The default listener is bound to 0.0.0.0 (all IPv4 interfaces), port 8000.
18use core::fmt;
19use core::ops::Deref;
20use core::str::FromStr;
21
22use std::collections::HashMap;
23use std::io::ErrorKind;
24use std::net::{IpAddr, Ipv4Addr, SocketAddr};
25use std::path::Path;
26
27use base64::engine::Engine as _;
28use base64::prelude::BASE64_STANDARD_NO_PAD;
29use http::Uri;
30use serde::{de, de::Deserializer, de::Visitor, Deserialize, Serialize};
31use tracing::{instrument, trace};
32use unicase::UniCase;
33
/// Origins allowed by default (an empty list).
const CORS_ALLOWED_ORIGINS: &[&str] = &[];
/// HTTP methods allowed by default.
const CORS_ALLOWED_METHODS: &[&str] = &["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"];
/// Request headers allowed by default.
const CORS_ALLOWED_HEADERS: &[&str] = &[
    "accept",
    "accept-language",
    "content-type",
    "content-language",
];
/// Response headers exposed by default (none).
const CORS_EXPOSED_HEADERS: &[&str] = &[];
/// Default CORS max age, in seconds (five minutes).
const CORS_DEFAULT_MAX_AGE_SECS: u64 = 5 * 60;
44
/// Returns the bind address used when none is configured: all IPv4 interfaces
/// (`0.0.0.0`) on port 8000.
pub fn default_listen_address() -> SocketAddr {
    SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 8000)
}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
50pub struct ServiceSettings {
51    /// Bind address
52    #[serde(default = "default_listen_address")]
53    pub address: SocketAddr,
54    /// cache control options
55    #[serde(default)]
56    pub cache_control: Option<String>,
57    /// Flag for read only mode
58    #[serde(default)]
59    pub readonly_mode: Option<bool>,
60    // cors config
61    pub cors_allowed_origins: Option<AllowedOrigins>,
62    pub cors_allowed_headers: Option<AllowedHeaders>,
63    pub cors_allowed_methods: Option<AllowedMethods>,
64    pub cors_exposed_headers: Option<ExposedHeaders>,
65    pub cors_max_age_secs: Option<u64>,
66    // tls config
67    #[serde(default)]
68    /// path to server X.509 cert chain file. Must be PEM-encoded
69    pub tls_cert_file: Option<String>,
70    #[serde(default)]
71    pub tls_priv_key_file: Option<String>,
72    /// Rpc timeout - how long (milliseconds) to wait for component's response
73    /// before returning a status 503 to the http client
74    /// If not set, uses the system-wide rpc timeout
75    #[serde(default)]
76    pub timeout_ms: Option<u64>,
77    // DEPRECATED due to the nested struct being poorly supported by wasmCloud config
78    #[deprecated(since = "0.22.0", note = "Use top-level fields instead")]
79    #[serde(default)]
80    pub tls: Tls,
81    #[deprecated(since = "0.22.0", note = "Use top-level fields instead")]
82    #[serde(default)]
83    pub cors: Cors,
84    #[serde(default)]
85    pub disable_keepalive: Option<bool>,
86}
87
88impl Default for ServiceSettings {
89    fn default() -> ServiceSettings {
90        #[allow(deprecated)]
91        ServiceSettings {
92            address: default_listen_address(),
93            cors_allowed_origins: Some(AllowedOrigins::default()),
94            cors_allowed_headers: Some(AllowedHeaders::default()),
95            cors_allowed_methods: Some(AllowedMethods::default()),
96            cors_exposed_headers: Some(ExposedHeaders::default()),
97            cors_max_age_secs: Some(CORS_DEFAULT_MAX_AGE_SECS),
98            tls_cert_file: None,
99            tls_priv_key_file: None,
100            timeout_ms: None,
101            cache_control: None,
102            readonly_mode: Some(false),
103            tls: Tls::default(),
104            cors: Cors::default(),
105            disable_keepalive: None,
106        }
107    }
108}
109
110impl ServiceSettings {
111    /// load settings from json, flattening nested fields
112    fn from_json(data: &str) -> Result<Self, HttpServerError> {
113        #[allow(deprecated)]
114        serde_json::from_str(data)
115            // For backwards compatibility, we can pull the values from the `tls` and `cors` fields
116            // and merge them into the top-level fields.
117            .map(|s: ServiceSettings| ServiceSettings {
118                address: s.address,
119                cache_control: s.cache_control,
120                readonly_mode: s.readonly_mode,
121                timeout_ms: s.timeout_ms,
122                tls_cert_file: s.tls_cert_file.or(s.tls.cert_file),
123                tls_priv_key_file: s.tls_priv_key_file.or(s.tls.priv_key_file),
124                cors_allowed_origins: s.cors_allowed_origins.or(s.cors.allowed_origins),
125                cors_allowed_headers: s.cors_allowed_headers.or(s.cors.allowed_headers),
126                cors_allowed_methods: s.cors_allowed_methods.or(s.cors.allowed_methods),
127                cors_exposed_headers: s.cors_exposed_headers.or(s.cors.exposed_headers),
128                cors_max_age_secs: s.cors_max_age_secs.or(s.cors.max_age_secs),
129                tls: Tls::default(),
130                cors: Cors::default(),
131                disable_keepalive: s.disable_keepalive,
132            })
133            .map_err(|e| HttpServerError::Settings(format!("invalid json: {e}")))
134    }
135
136    /// perform additional validation checks on settings.
137    /// Several checks have already been done during deserialization.
138    /// All errors found are combined into a single error message
139    fn validate(&self) -> Result<(), HttpServerError> {
140        let mut errors = Vec::new();
141        // 1. make sure tls config is valid
142        match (&self.tls_cert_file, &self.tls_priv_key_file) {
143            (None, None) => {}
144            (Some(_), None) | (None, Some(_)) => {
145                errors.push(
146                    "for tls, both 'tls_cert_file' and 'tls_priv_key_file' must be set".to_string(),
147                );
148            }
149            (Some(cert_file), Some(key_file)) => {
150                for f in &[("cert_file", &cert_file), ("priv_key_file", &key_file)] {
151                    let path: &Path = f.1.as_ref();
152                    if !path.is_file() {
153                        errors.push(format!(
154                            "missing tls_{} '{}'{}",
155                            f.0,
156                            &path.display(),
157                            if path.is_absolute() {
158                                ""
159                            } else {
160                                " : perhaps you should make the path absolute"
161                            }
162                        ));
163                    }
164                }
165            }
166        }
167        if let Some(ref methods) = self.cors_allowed_methods {
168            for m in &methods.0 {
169                if http::Method::try_from(m.as_str()).is_err() {
170                    errors.push(format!("invalid CORS method: '{m}'"));
171                }
172            }
173        }
174        if let Some(cache_control) = self.cache_control.as_ref() {
175            if http::HeaderValue::from_str(cache_control).is_err() {
176                errors.push(format!("Invalid Cache Control header : '{cache_control}'"));
177            }
178        }
179        if !errors.is_empty() {
180            Err(HttpServerError::Settings(format!(
181                "\nInvalid httpserver settings: \n{}\n",
182                errors.join("\n")
183            )))
184        } else {
185            Ok(())
186        }
187    }
188}
189
190/// Errors generated by this HTTP server
191#[derive(Debug, thiserror::Error)]
192pub enum HttpServerError {
193    #[error("invalid parameter: {0}")]
194    InvalidParameter(String),
195
196    #[error("problem reading settings: {0}")]
197    Settings(String),
198}
199
200/// Load settings provides a flexible means for loading configuration.
201/// Return value is any structure with Deserialize, or for example, HashMap<String,String>
202///   config_b64:  base64-encoded json string
203///   config_json: raw json string
204/// Also accept "address" (a string representing SocketAddr) and "port", a localhost port
205/// If more than one key is provided, they are processed in the order above.
206///   (later names override earlier names in the list)
207///
208#[instrument]
209pub fn load_settings(
210    default_address: Option<SocketAddr>,
211    values: &HashMap<String, String>,
212) -> Result<ServiceSettings, HttpServerError> {
213    trace!("load settings");
214    // Allow keys to be case insensitive, as an accommodation
215    // for the lost souls who prefer sPoNgEbOb CaSe variable names.
216    let values: HashMap<UniCase<&str>, &String> = values
217        .iter()
218        .map(|(k, v)| (UniCase::new(k.as_str()), v))
219        .collect();
220
221    if let Some(str) = values.get(&UniCase::new("config_b64")) {
222        let bytes = BASE64_STANDARD_NO_PAD
223            .decode(str)
224            .map_err(|e| HttpServerError::Settings(format!("invalid base64 encoding: {e}")))?;
225        return ServiceSettings::from_json(&String::from_utf8_lossy(&bytes));
226    }
227
228    if let Some(str) = values.get(&UniCase::new("config_json")) {
229        return ServiceSettings::from_json(str);
230    }
231
232    let mut settings = ServiceSettings::default();
233
234    // accept port, for compatibility with previous implementations
235    if let Some(addr) = values.get(&UniCase::new("port")) {
236        let port = addr
237            .parse::<u16>()
238            .map_err(|_| HttpServerError::InvalidParameter(format!("Invalid port: {addr}")))?;
239        settings.address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port);
240    }
241    // accept address as value parameter
242    settings.address = values
243        .get(&UniCase::new("address"))
244        .map(|addr| {
245            SocketAddr::from_str(addr)
246                .map_err(|_| HttpServerError::InvalidParameter(format!("invalid address: {addr}")))
247        })
248        .transpose()?
249        .or(default_address)
250        .unwrap_or_else(default_listen_address);
251
252    // accept cache-control header values
253    if let Some(cache_control) = values.get(&UniCase::new("cache_control")) {
254        settings.cache_control = Some(cache_control.to_string());
255    }
256    // accept read only mode flag
257    if let Some(readonly_mode) = values.get(&UniCase::new("readonly_mode")) {
258        settings.readonly_mode = Some(readonly_mode.to_string().parse().unwrap_or(false));
259    }
260    // accept timeout_ms flag
261    if let Some(Ok(timeout_ms)) = values.get(&UniCase::new("timeout_ms")).map(|s| s.parse()) {
262        settings.timeout_ms = Some(timeout_ms)
263    }
264
265    // TLS
266    if let Some(tls_cert_file) = values.get(&UniCase::new("tls_cert_file")) {
267        settings.tls_cert_file = Some(tls_cert_file.to_string());
268    }
269    if let Some(tls_priv_key_file) = values.get(&UniCase::new("tls_priv_key_file")) {
270        settings.tls_priv_key_file = Some(tls_priv_key_file.to_string());
271    }
272
273    // CORS
274    if let Some(cors_allowed_origins) = values.get(&UniCase::new("cors_allowed_origins")) {
275        let origins: Vec<CorsOrigin> = serde_json::from_str(cors_allowed_origins)
276            .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_origins: {e}")))?;
277        settings.cors_allowed_origins = Some(AllowedOrigins(origins));
278    }
279    if let Some(cors_allowed_headers) = values.get(&UniCase::new("cors_allowed_headers")) {
280        let headers: Vec<String> = serde_json::from_str(cors_allowed_headers)
281            .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_headers: {e}")))?;
282        settings.cors_allowed_headers = Some(AllowedHeaders(headers));
283    }
284    if let Some(cors_allowed_methods) = values.get(&UniCase::new("cors_allowed_methods")) {
285        let methods: Vec<String> = serde_json::from_str(cors_allowed_methods)
286            .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_methods: {e}")))?;
287        settings.cors_allowed_methods = Some(AllowedMethods(methods));
288    }
289    if let Some(cors_exposed_headers) = values.get(&UniCase::new("cors_exposed_headers")) {
290        let headers: Vec<String> = serde_json::from_str(cors_exposed_headers)
291            .map_err(|e| HttpServerError::Settings(format!("invalid cors_exposed_headers: {e}")))?;
292        settings.cors_exposed_headers = Some(ExposedHeaders(headers));
293    }
294    if let Some(cors_max_age_secs) = values.get(&UniCase::new("cors_max_age_secs")) {
295        let max_age_secs: u64 = cors_max_age_secs.parse().map_err(|_| {
296            HttpServerError::InvalidParameter("Invalid cors_max_age_secs".to_string())
297        })?;
298        settings.cors_max_age_secs = Some(max_age_secs);
299    }
300    if let Some(disable_keepalive) = values.get(&UniCase::new("disable_keepalive")) {
301        settings.disable_keepalive = Some(disable_keepalive.parse().unwrap_or(false));
302    }
303
304    settings.validate()?;
305    Ok(settings)
306}
307
308#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
309pub struct Tls {
310    /// path to server X.509 cert chain file. Must be PEM-encoded
311    pub cert_file: Option<String>,
312    pub priv_key_file: Option<String>,
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
316pub struct Cors {
317    pub allowed_origins: Option<AllowedOrigins>,
318    pub allowed_headers: Option<AllowedHeaders>,
319    pub allowed_methods: Option<AllowedMethods>,
320    pub exposed_headers: Option<ExposedHeaders>,
321    pub max_age_secs: Option<u64>,
322}
323
324impl Default for Cors {
325    fn default() -> Self {
326        Cors {
327            allowed_origins: Some(AllowedOrigins::default()),
328            allowed_headers: Some(AllowedHeaders::default()),
329            allowed_methods: Some(AllowedMethods::default()),
330            exposed_headers: Some(ExposedHeaders::default()),
331            max_age_secs: Some(CORS_DEFAULT_MAX_AGE_SECS),
332        }
333    }
334}
335
336#[derive(Debug, Clone, Default, Serialize, PartialEq, Eq)]
337pub struct CorsOrigin(String);
338
339#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
340pub struct AllowedOrigins(Vec<CorsOrigin>);
341
342#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
343pub struct AllowedHeaders(Vec<String>);
344
345#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
346pub struct AllowedMethods(Vec<String>);
347
348#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
349pub struct ExposedHeaders(Vec<String>);
350
351impl<'de> Deserialize<'de> for CorsOrigin {
352    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
353    where
354        D: Deserializer<'de>,
355    {
356        struct CorsOriginVisitor;
357        impl Visitor<'_> for CorsOriginVisitor {
358            type Value = CorsOrigin;
359
360            fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
361                write!(fmt, "an origin in format http[s]://example.com[:3000]",)
362            }
363
364            fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
365            where
366                E: de::Error,
367            {
368                CorsOrigin::from_str(v).map_err(E::custom)
369            }
370        }
371        deserializer.deserialize_str(CorsOriginVisitor)
372    }
373}
374
375impl FromStr for CorsOrigin {
376    type Err = std::io::Error;
377
378    fn from_str(origin: &str) -> Result<Self, Self::Err> {
379        let uri = Uri::from_str(origin).map_err(|invalid_uri| {
380            std::io::Error::new(
381                ErrorKind::InvalidInput,
382                format!("Invalid uri: {origin}.\n{invalid_uri}"),
383            )
384        })?;
385        if let Some(s) = uri.scheme_str() {
386            if s != "http" && s != "https" {
387                return Err(std::io::Error::new(
388                    ErrorKind::InvalidInput,
389                    format!(
390                        "Cors origin invalid schema {}, only [http] and [https] are supported: ",
391                        uri.scheme_str().unwrap()
392                    ),
393                ));
394            }
395        } else {
396            return Err(std::io::Error::new(
397                ErrorKind::InvalidInput,
398                "Cors origin missing schema, only [http] or [https] are supported",
399            ));
400        }
401
402        if let Some(p) = uri.path_and_query() {
403            if p.as_str() != "/" {
404                return Err(std::io::Error::new(
405                    ErrorKind::InvalidInput,
406                    format!("Invalid value {} in cors schema.", p.as_str()),
407                ));
408            }
409        }
410        Ok(CorsOrigin(origin.trim_end_matches('/').to_owned()))
411    }
412}
413
414impl AsRef<str> for CorsOrigin {
415    fn as_ref(&self) -> &str {
416        &self.0
417    }
418}
419
420impl Deref for AllowedOrigins {
421    type Target = Vec<CorsOrigin>;
422
423    fn deref(&self) -> &Self::Target {
424        &self.0
425    }
426}
427
428impl Default for AllowedOrigins {
429    fn default() -> Self {
430        AllowedOrigins(
431            CORS_ALLOWED_ORIGINS
432                .iter()
433                .map(|s| CorsOrigin((*s).to_string()))
434                .collect::<Vec<_>>(),
435        )
436    }
437}
438
439impl Deref for AllowedHeaders {
440    type Target = Vec<String>;
441
442    fn deref(&self) -> &Self::Target {
443        &self.0
444    }
445}
446
447impl Default for AllowedHeaders {
448    fn default() -> Self {
449        AllowedHeaders(from_defaults(CORS_ALLOWED_HEADERS))
450    }
451}
452
453impl Default for AllowedMethods {
454    fn default() -> Self {
455        AllowedMethods(from_defaults(CORS_ALLOWED_METHODS))
456    }
457}
458
459impl Deref for AllowedMethods {
460    type Target = Vec<String>;
461
462    fn deref(&self) -> &Self::Target {
463        &self.0
464    }
465}
466
467impl Deref for ExposedHeaders {
468    type Target = Vec<String>;
469
470    fn deref(&self) -> &Self::Target {
471        &self.0
472    }
473}
474
475impl Default for ExposedHeaders {
476    fn default() -> Self {
477        ExposedHeaders(
478            CORS_EXPOSED_HEADERS
479                .iter()
480                .map(|s| (*s).to_string())
481                .collect::<Vec<_>>(),
482        )
483    }
484}
485
486#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
487#[serde(rename_all = "UPPERCASE")]
488pub enum HttpMethod {
489    Get,
490    Post,
491    Put,
492    Delete,
493    Head,
494    Options,
495    Connect,
496    Patch,
497    Trace,
498}
499
500impl FromStr for HttpMethod {
501    type Err = std::io::Error;
502
503    fn from_str(s: &str) -> Result<Self, Self::Err> {
504        match s.to_uppercase().as_str() {
505            "GET" => Ok(Self::Get),
506            "PUT" => Ok(Self::Put),
507            "POST" => Ok(Self::Post),
508            "DELETE" => Ok(Self::Delete),
509            "HEAD" => Ok(Self::Head),
510            "OPTIONS" => Ok(Self::Options),
511            "CONNECT" => Ok(Self::Connect),
512            "PATCH" => Ok(Self::Patch),
513            "TRACE" => Ok(Self::Trace),
514            _ => Err(std::io::Error::new(
515                std::io::ErrorKind::InvalidData,
516                format!("{s} is not a valid http method"),
517            )),
518        }
519    }
520}
521
/// Converts a slice of string literals into a `Vec<T>` for any `T` constructible from
/// `&str`.
///
/// Used to materialize the `CORS_*` default lists. The conversion is infallible.
fn from_defaults<'d, T>(d: &[&'d str]) -> Vec<T>
where
    T: From<&'d str>,
{
    d.iter().copied().map(T::from).collect()
}
530
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use crate::settings::{CorsOrigin, ServiceSettings};

    const GOOD_ORIGINS: &[&str] = &[
        // origins that should be parsed correctly
        "https://www.example.com",
        "https://www.example.com:1000",
        "http://localhost",
        "http://localhost:8080",
        "http://127.0.0.1",
        "http://127.0.0.1:8080",
        "https://:8080",
    ];

    const BAD_ORIGINS: &[&str] = &[
        // invalid origin syntax
        "ftp://www.example.com", // only http,https allowed
        "localhost",
        "127.0.0.1",
        "127.0.0.1:8080",
        ":8080",
        "/path/file.txt",
        "http:",
        "https://",
    ];

    #[test]
    fn settings_init() {
        let s = ServiceSettings::default();
        assert!(s.address.is_ipv4());
        assert!(s.cors_allowed_methods.is_some());
        assert!(s.cors_allowed_origins.is_some());
        assert!(s.cors_allowed_origins.unwrap().0.is_empty());
    }

    #[test]
    fn settings_json() {
        let json = r#"{
        "cors": {
            "allowed_headers": [ "X-Cookies" ]
         }
         }"#;

        let s = ServiceSettings::from_json(json).expect("parse_json");
        assert_eq!(s.cors_allowed_headers.as_ref().unwrap().0.len(), 1);
        assert_eq!(
            s.cors_allowed_headers.as_ref().unwrap().0.first().unwrap(),
            "X-Cookies"
        );
    }

    #[test]
    fn origins_deserialize() {
        // test CorsOrigin
        for valid in GOOD_ORIGINS {
            let o = serde_json::from_value::<CorsOrigin>(serde_json::Value::String(
                (*valid).to_string(),
            ));
            assert!(o.is_ok(), "from_value '{valid}'");

            // test as_ref()
            let origin = o.unwrap();
            let s: &str = origin.as_ref();
            assert_eq!(s, *valid);
        }
    }

    #[test]
    fn origins_from_str() {
        // test CorsOrigin
        for &valid in GOOD_ORIGINS {
            let o = CorsOrigin::from_str(valid);
            println!("{valid}: {o:?}");
            assert!(o.is_ok(), "from_str '{valid}'");

            // test as_ref()
            let origin = o.unwrap();
            let s: &str = origin.as_ref();
            assert_eq!(s, valid);
        }
    }

    #[test]
    fn origins_negative() {
        for bad in BAD_ORIGINS {
            let o =
                serde_json::from_value::<CorsOrigin>(serde_json::Value::String((*bad).to_string()));
            println!("{bad}: {o:?}");
            assert!(o.is_err(), "from_value '{bad}' (expect err)");

            // Quote the value as a JSON string so the failure comes from origin
            // validation, not from `bad` being syntactically invalid JSON.
            let json = serde_json::to_string(bad).expect("serialize origin string");
            let o = serde_json::from_str::<CorsOrigin>(&json);
            println!("{bad}: {o:?}");
            assert!(o.is_err(), "from_str '{bad}' (expect err)");
        }
    }
}