spring_web/
config.rs

1use schemars::JsonSchema;
2use serde::Deserialize;
3use spring::config::Configurable;
4use std::net::{IpAddr, Ipv4Addr};
5use tracing::Level;
6
7spring::submit_config_schema!("web", WebConfig);
8
9#[cfg(feature = "socket_io")]
10spring::submit_config_schema!("socket_io", SocketIOConfig);
11
12/// spring-web Config
13#[derive(Debug, Configurable, JsonSchema, Deserialize)]
14#[config_prefix = "web"]
15pub struct WebConfig {
16    #[serde(flatten)]
17    pub(crate) server: ServerConfig,
18    #[cfg(feature = "openapi")]
19    pub(crate) openapi: OpenApiConfig,
20    pub(crate) middlewares: Option<Middlewares>,
21}
22
23#[derive(Debug, Clone, JsonSchema, Deserialize)]
24pub struct ServerConfig {
25    #[serde(default = "default_binding")]
26    pub(crate) binding: IpAddr,
27    #[serde(default = "default_port")]
28    pub(crate) port: u16,
29    #[serde(default)]
30    pub(crate) connect_info: bool,
31    #[serde(default = "default_true")]
32    pub(crate) graceful: bool,
33}
34
#[cfg(feature = "openapi")]
#[derive(Clone, Debug, Deserialize, JsonSchema)]
pub struct OpenApiConfig {
    // Docs prefix; `default_doc_prefix()` supplies "/docs" when omitted.
    #[serde(default = "default_doc_prefix")]
    pub(crate) doc_prefix: String,
    // OpenAPI `info` object; `Info::default()` is used when omitted.
    #[serde(default)]
    pub(crate) info: aide::openapi::Info,
}
43
/// Serde default for `ServerConfig::binding`: the IPv4 unspecified
/// address `0.0.0.0`.
fn default_binding() -> IpAddr {
    Ipv4Addr::UNSPECIFIED.into()
}
47
/// Serde default for `ServerConfig::port`.
fn default_port() -> u16 {
    const DEFAULT_PORT: u16 = 8080;
    DEFAULT_PORT
}
51
/// Serde default for boolean options that are enabled unless the
/// configuration turns them off.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
55
/// Serde default for `OpenApiConfig::doc_prefix`.
#[cfg(feature = "openapi")]
fn default_doc_prefix() -> String {
    String::from("/docs")
}
60
61/// Server middleware configuration structure.
62#[derive(Debug, Clone, JsonSchema, Deserialize)]
63pub struct Middlewares {
64    /// Middleware that enable compression for the response.
65    pub compression: Option<EnableMiddleware>,
66    /// Middleware that limit the payload request.
67    pub limit_payload: Option<LimitPayloadMiddleware>,
68    /// Middleware that improve the tracing logger and adding trace id for each
69    /// request.
70    pub logger: Option<TraceLoggerMiddleware>,
71    /// catch any code panic and log the error.
72    pub catch_panic: Option<EnableMiddleware>,
73    /// Setting a global timeout for the requests
74    pub timeout_request: Option<TimeoutRequestMiddleware>,
75    /// Setting cors configuration
76    pub cors: Option<CorsMiddleware>,
77    /// Serving static assets
78    #[serde(rename = "static")]
79    pub static_assets: Option<StaticAssetsMiddleware>,
80}
81
82/// Static asset middleware configuration
83#[derive(Debug, Clone, JsonSchema, Deserialize)]
84pub struct StaticAssetsMiddleware {
85    /// toggle enable
86    pub enable: bool,
87    /// Check that assets must exist on disk
88    #[serde(default = "bool::default")]
89    pub must_exist: bool,
90    /// Fallback page for a case when no asset exists (404). Useful for SPA
91    /// (single page app) where routes are virtual.
92    #[serde(default = "default_fallback")]
93    pub fallback: String,
94    /// Enable `precompressed_gzip`
95    #[serde(default = "bool::default")]
96    pub precompressed: bool,
97    /// Uri for the assets
98    #[serde(default = "default_assets_uri")]
99    pub uri: String,
100    /// Path for the assets
101    #[serde(default = "default_assets_path")]
102    pub path: String,
103}
104
105/// CORS middleware configuration
106#[derive(Debug, Clone, JsonSchema, Deserialize)]
107pub struct TraceLoggerMiddleware {
108    /// toggle enable
109    pub enable: bool,
110    pub level: LogLevel,
111}
112
113#[derive(Debug, Default, Clone, JsonSchema, Deserialize)]
114pub enum LogLevel {
115    /// The "trace" level.
116    #[serde(rename = "trace")]
117    Trace,
118    /// The "debug" level.
119    #[serde(rename = "debug")]
120    Debug,
121    /// The "info" level.
122    #[serde(rename = "info")]
123    #[default]
124    Info,
125    /// The "warn" level.
126    #[serde(rename = "warn")]
127    Warn,
128    /// The "error" level.
129    #[serde(rename = "error")]
130    Error,
131}
132
133#[allow(clippy::from_over_into)]
134impl Into<Level> for LogLevel {
135    fn into(self) -> Level {
136        match self {
137            Self::Trace => Level::TRACE,
138            Self::Debug => Level::DEBUG,
139            Self::Info => Level::INFO,
140            Self::Warn => Level::WARN,
141            Self::Error => Level::ERROR,
142        }
143    }
144}
145
146/// CORS middleware configuration
147#[derive(Debug, Clone, JsonSchema, Deserialize)]
148pub struct CorsMiddleware {
149    /// toggle enable
150    pub enable: bool,
151    /// Allow origins
152    pub allow_origins: Option<Vec<String>>,
153    /// Allow headers
154    pub allow_headers: Option<Vec<String>>,
155    /// Allow methods
156    pub allow_methods: Option<Vec<String>>,
157    /// Max age
158    pub max_age: Option<u64>,
159}
160
161/// Timeout middleware configuration
162#[derive(Debug, Clone, JsonSchema, Deserialize)]
163pub struct TimeoutRequestMiddleware {
164    /// toggle enable
165    pub enable: bool,
166    /// Timeout request in milliseconds
167    pub timeout: u64,
168}
169
170/// Limit payload size middleware configuration
171#[derive(Debug, Clone, JsonSchema, Deserialize)]
172pub struct LimitPayloadMiddleware {
173    /// toggle enable
174    pub enable: bool,
175    /// Body limit. for example: 5mb
176    pub body_limit: String,
177}
178
179/// A generic middleware configuration that can be enabled or
180/// disabled.
181#[derive(Debug, PartialEq, Clone, JsonSchema, Deserialize)]
182pub struct EnableMiddleware {
183    /// toggle enable
184    pub enable: bool,
185}
186
/// Serde default for `StaticAssetsMiddleware::path`.
fn default_assets_path() -> String {
    String::from("static")
}
190
/// Serde default for `StaticAssetsMiddleware::uri`.
fn default_assets_uri() -> String {
    String::from("/static")
}
194
/// Serde default for `StaticAssetsMiddleware::fallback`.
fn default_fallback() -> String {
    "index.html".to_owned()
}
198
/// SocketIO configuration
#[cfg(feature = "socket_io")]
#[derive(Debug, Deserialize, JsonSchema, Configurable)]
#[config_prefix = "socket_io"]
pub struct SocketIOConfig {
    // `default_namespace()` supplies "/" when the key is omitted.
    #[serde(default = "default_namespace")]
    pub default_namespace: String,
}
207
/// Serde default for `SocketIOConfig::default_namespace`: the root namespace.
#[cfg(feature = "socket_io")]
fn default_namespace() -> String {
    String::from("/")
}