1use schemars::JsonSchema;
2use serde::Deserialize;
3use spring::config::Configurable;
4use std::net::{IpAddr, Ipv4Addr};
5use tracing::Level;
6
// Submit the config schema for `WebConfig` to spring under the `web` key.
spring::submit_config_schema!("web", WebConfig);

// Only submitted when the `socket_io` feature is enabled, matching the
// feature gate on `SocketIOConfig` itself.
#[cfg(feature = "socket_io")]
spring::submit_config_schema!("socket_io", SocketIOConfig);
11
/// Top-level configuration of the web plugin, read from the `web` config section.
#[derive(Debug, Configurable, JsonSchema, Deserialize)]
#[config_prefix = "web"]
pub struct WebConfig {
    /// Listener settings. `flatten` means these keys (`binding`, `port`, ...)
    /// live directly under `web` instead of in a nested table.
    #[serde(flatten)]
    pub(crate) server: ServerConfig,
    /// OpenAPI settings; only exists with the `openapi` feature. There is no
    /// serde default on this field, so the section is required when the feature is on.
    #[cfg(feature = "openapi")]
    pub(crate) openapi: OpenApiConfig,
    /// Optional middleware settings; `None` when the section is absent.
    pub(crate) middlewares: Option<Middlewares>,
}
22
/// Address/port settings for the web server. Every field has a serde default,
/// so an empty section is valid.
#[derive(Debug, Clone, JsonSchema, Deserialize)]
pub struct ServerConfig {
    /// IP address to bind; defaults to `0.0.0.0` (see `default_binding`).
    #[serde(default = "default_binding")]
    pub(crate) binding: IpAddr,
    /// Port number; defaults to `8080` (see `default_port`).
    #[serde(default = "default_port")]
    pub(crate) port: u16,
    /// Defaults to `false`.
    /// NOTE(review): presumably exposes connection info (peer address) to handlers — confirm at the server setup.
    #[serde(default)]
    pub(crate) connect_info: bool,
    /// Defaults to `false`.
    /// NOTE(review): presumably enables graceful shutdown — confirm at the server setup.
    #[serde(default)]
    pub(crate) graceful: bool,
}
34
/// OpenAPI documentation settings (only compiled with the `openapi` feature).
#[cfg(feature = "openapi")]
#[derive(Debug, Clone, JsonSchema, Deserialize)]
pub struct OpenApiConfig {
    /// Path prefix for the docs; defaults to `"/docs"` (see `default_doc_prefix`).
    #[serde(default = "default_doc_prefix")]
    pub(crate) doc_prefix: String,
    /// `aide` OpenAPI `Info` object; falls back to `Info::default()` when omitted.
    #[serde(default)]
    pub(crate) info: aide::openapi::Info,
}
43
/// Serde default for `ServerConfig::binding`: the IPv4 unspecified address
/// `0.0.0.0`, i.e. all local IPv4 interfaces.
fn default_binding() -> IpAddr {
    Ipv4Addr::UNSPECIFIED.into()
}
47
/// Serde default for `ServerConfig::port`.
fn default_port() -> u16 {
    const DEFAULT_PORT: u16 = 8080;
    DEFAULT_PORT
}
51
/// Serde default for `OpenApiConfig::doc_prefix`.
#[cfg(feature = "openapi")]
fn default_doc_prefix() -> String {
    String::from("/docs")
}
56
/// Per-middleware settings. Every field is an `Option`, so each middleware's
/// section may be omitted (it then deserializes as `None`).
#[derive(Debug, Clone, JsonSchema, Deserialize)]
pub struct Middlewares {
    /// On/off switch for the compression middleware.
    pub compression: Option<EnableMiddleware>,
    /// Payload size limit middleware settings.
    pub limit_payload: Option<LimitPayloadMiddleware>,
    /// Trace logger middleware settings.
    pub logger: Option<TraceLoggerMiddleware>,
    /// On/off switch for the catch-panic middleware.
    pub catch_panic: Option<EnableMiddleware>,
    /// Request timeout middleware settings.
    pub timeout_request: Option<TimeoutRequestMiddleware>,
    /// CORS middleware settings.
    pub cors: Option<CorsMiddleware>,
    /// Static asset settings. The config key is `static`, which is a reserved
    /// Rust keyword, hence the different field name plus `rename`.
    #[serde(rename = "static")]
    pub static_assets: Option<StaticAssetsMiddleware>,
}
77
78#[derive(Debug, Clone, JsonSchema, Deserialize)]
80pub struct StaticAssetsMiddleware {
81    pub enable: bool,
83    #[serde(default = "bool::default")]
85    pub must_exist: bool,
86    #[serde(default = "default_fallback")]
89    pub fallback: String,
90    #[serde(default = "bool::default")]
92    pub precompressed: bool,
93    #[serde(default = "default_assets_uri")]
95    pub uri: String,
96    #[serde(default = "default_assets_path")]
98    pub path: String,
99}
100
101#[derive(Debug, Clone, JsonSchema, Deserialize)]
103pub struct TraceLoggerMiddleware {
104    pub enable: bool,
106    pub level: LogLevel,
107}
108
109#[derive(Debug, Default, Clone, JsonSchema, Deserialize)]
110pub enum LogLevel {
111    #[serde(rename = "trace")]
113    Trace,
114    #[serde(rename = "debug")]
116    Debug,
117    #[serde(rename = "info")]
119    #[default]
120    Info,
121    #[serde(rename = "warn")]
123    Warn,
124    #[serde(rename = "error")]
126    Error,
127}
128
129#[allow(clippy::from_over_into)]
130impl Into<Level> for LogLevel {
131    fn into(self) -> Level {
132        match self {
133            Self::Trace => Level::TRACE,
134            Self::Debug => Level::DEBUG,
135            Self::Info => Level::INFO,
136            Self::Warn => Level::WARN,
137            Self::Error => Level::ERROR,
138        }
139    }
140}
141
/// CORS middleware settings. The list fields are optional and `None` when omitted.
/// NOTE(review): what `None` means (allow-all vs. nothing) is decided by the
/// code that builds the CORS layer — confirm there.
#[derive(Debug, Clone, JsonSchema, Deserialize)]
pub struct CorsMiddleware {
    /// Whether the middleware is active. Required: there is no serde default.
    pub enable: bool,
    /// Allowed origins.
    pub allow_origins: Option<Vec<String>>,
    /// Allowed request headers.
    pub allow_headers: Option<Vec<String>>,
    /// Allowed HTTP methods, given as strings.
    pub allow_methods: Option<Vec<String>>,
    /// Max age. NOTE(review): presumably the preflight cache duration in seconds — confirm.
    pub max_age: Option<u64>,
}
156
/// Request timeout middleware settings. Both fields are required.
#[derive(Debug, Clone, JsonSchema, Deserialize)]
pub struct TimeoutRequestMiddleware {
    /// Whether the middleware is active.
    pub enable: bool,
    /// Timeout value. NOTE(review): the unit is not visible here (likely
    /// milliseconds) — confirm where it is turned into a `Duration`.
    pub timeout: u64,
}
165
/// Payload size limit middleware settings. Both fields are required.
#[derive(Debug, Clone, JsonSchema, Deserialize)]
pub struct LimitPayloadMiddleware {
    /// Whether the middleware is active.
    pub enable: bool,
    /// Body size limit, kept as a string here.
    /// NOTE(review): presumably a human-readable size such as `"5mb"` parsed
    /// elsewhere — confirm the accepted format at the parser.
    pub body_limit: String,
}
174
/// Settings for a middleware whose only option is an on/off switch
/// (used for `compression` and `catch_panic`).
#[derive(Debug, PartialEq, Clone, JsonSchema, Deserialize)]
pub struct EnableMiddleware {
    /// Whether the middleware is active. Required: there is no serde default.
    pub enable: bool,
}
182
/// Serde default for `StaticAssetsMiddleware::path`: the relative path `static`.
fn default_assets_path() -> String {
    String::from("static")
}
186
/// Serde default for `StaticAssetsMiddleware::uri`.
fn default_assets_uri() -> String {
    "/static".to_owned()
}
190
/// Serde default for `StaticAssetsMiddleware::fallback`.
fn default_fallback() -> String {
    let file_name: &str = "index.html";
    file_name.into()
}
194
/// Socket.IO settings, read from the `socket_io` config section
/// (only compiled with the `socket_io` feature).
#[cfg(feature = "socket_io")]
#[derive(Debug, Configurable, JsonSchema, Deserialize)]
#[config_prefix = "socket_io"]
pub struct SocketIOConfig {
    /// Default namespace; defaults to `"/"` (see `default_namespace`).
    #[serde(default = "default_namespace")]
    pub default_namespace: String,
}
203
/// Serde default for `SocketIOConfig::default_namespace`: the root namespace `/`.
#[cfg(feature = "socket_io")]
fn default_namespace() -> String {
    String::from("/")
}