1use serde::de::{Error, MapAccess, SeqAccess, Visitor};
2use serde::{de, Deserialize, Deserializer, Serialize};
3use std::env::VarError;
4use std::fmt;
5use std::fmt::{Display, Formatter};
6use std::fs::File;
7use std::io::BufReader;
8use std::marker::PhantomData;
9use std::net::SocketAddr;
10use std::path::PathBuf;
11use std::str::FromStr;
12
13pub fn get_config() -> Config {
45 env_logger::init();
46 let config_path = std::env::var("CONFIG").unwrap_or_else(|e| match e {
47 VarError::NotPresent => String::from("ENV"),
48 e => panic!("{}", e),
49 });
50
51 match ConfigParsingStrategy::from_str(&config_path).unwrap() {
52 ConfigParsingStrategy::Env => from_env().unwrap(),
53 ConfigParsingStrategy::Yaml(r) => from_yaml(&r).unwrap(),
54 ConfigParsingStrategy::Json(r) => from_json(&r).unwrap(),
55 }
56}
57
58fn from_env() -> Result<Config, std::env::VarError> {
59 Ok(Config {
60 addr: from_env_optional("ADDR")?.map(|a| SocketAddr::from_str(&a).expect("invalid addr")),
61 tls: tls_from_env()?,
62 secure: secure_from_env()?,
63 queue: queue_from_env()?,
64 service_discovery: service_discovery_from_env()?,
65 websocket: websocket_from_env()?,
66 garbage_collector: garbage_collector_from_env()?,
67 })
68}
69
70fn tls_from_env() -> Result<Option<Tls>, std::env::VarError> {
71 let private_key = from_env_optional("TLS_PRIVATE_KEY")?;
72 let cert = from_env_optional("TLS_CERT")?;
73
74 Ok(private_key
75 .and_then(|p| Some((p, cert?)))
76 .map(|(p, c)| Tls {
77 private_key: p,
78 cert: c,
79 }))
80}
81
82fn secure_from_env() -> Result<Option<Secure>, std::env::VarError> {
83 let jwt_token_expiration = from_env_optional("SECURE_JWT_EXPIRATION_TIME")?
84 .map(|e| e.parse().expect("invalid jwt expiration time"))
85 .unwrap_or_else(default_jwt_token_expiration);
86 let service_token = from_env_optional("SECURE_SERVICE_TOKEN")?.map(|st| Secure {
87 service_token: st,
88 jwt_token_expiration,
89 });
90 Ok(service_token)
91}
92
93fn garbage_collector_from_env() -> Result<GarbageCollector, std::env::VarError> {
94 let gb = from_env_optional("GARBAGE_COLLECTOR_INTERVAL")?
95 .map(|interval| GarbageCollector {
96 interval: interval.parse().expect("invalid garbage interval"),
97 })
98 .unwrap_or_default();
99 Ok(gb)
100}
101
102fn websocket_from_env() -> Result<WebSocket, std::env::VarError> {
103 let mut websocket = WebSocket::default();
104 if let Some(key) = from_env_optional("WEBSOCKET_KEY")? {
105 websocket.key = key;
106 }
107 if let Some(version) = from_env_optional("WEBSOCKET_VERSION")? {
108 websocket.version = version;
109 }
110 Ok(websocket)
111}
112
113fn queue_from_env() -> Result<Queue, std::env::VarError> {
114 let default: DefaultQueues = from_env_optional("QUEUE_DEFAULT")?
115 .map(|d| {
116 d.split(';')
117 .filter(|s| !s.is_empty())
118 .map(String::from)
119 .collect()
120 })
121 .unwrap_or_default();
122
123 let db_path = from_env_optional("QUEUE_DB_PATH")?.map(PathBuf::from);
124 let max_key_updates = from_env_optional("QUEUE_MAX_KEY_UPDATES")?
125 .map(|mku| mku.parse().expect("invalid max keys updates value"));
126 Ok(Queue {
127 default,
128 db_path,
129 max_key_updates,
130 })
131}
132
133fn service_discovery_from_env() -> Result<Option<ServiceDiscovery>, std::env::VarError> {
134 let service_discovery_type =
135 from_env_optional("SERVICE_DISCOVERY_TYPE")?.unwrap_or_else(|| String::from("API"));
136 let default_shards: Option<Shards> = from_env_optional("SERVICE_DISCOVERY_DEFAULT_SHARDS")?
137 .map(|d| {
138 d.split(';')
139 .filter(|s| !s.is_empty())
140 .map(String::from)
141 .collect()
142 });
143
144 let service_discovery = match service_discovery_type.as_str() {
145 "API" => ServiceDiscovery::Api {
146 default: default_shards,
147 },
148 "ETCD" => ServiceDiscovery::Etcd {
149 default: default_shards,
150 hosts: std::env::var("SERVICE_DISCOVERY_HOSTS")
151 .expect("empty service discovery hosts")
152 .split(';')
153 .filter(|s| !s.is_empty())
154 .map(String::from)
155 .collect(),
156 prefix: from_env_optional("SERVICE_DISCOVERY_PREFIX")?
157 .unwrap_or_else(default_sd_prefix),
158 instance_opts: instance_opts_from_env()?,
159 },
160 _ => panic!("Invalid service discovery type"),
161 };
162
163 Ok(Some(service_discovery))
164}
165
166fn instance_opts_from_env() -> Result<Option<ServiceDiscoveryInstanceOptions>, std::env::VarError> {
167 let instance_addr = from_env_optional("SERVICE_DISCOVERY_INSTANCE_ADDR")?;
168 let instance_id = from_env_optional("SERVICE_DISCOVERY_INSTANCE_id")?;
169
170 Ok(instance_addr.map(|ia| ServiceDiscoveryInstanceOptions {
171 instance_addr: ia,
172 instance_id,
173 }))
174}
175
/// Reads an environment variable, treating "not present" as `None`.
///
/// # Errors
///
/// Returns any [`VarError`] other than [`VarError::NotPresent`] unchanged.
fn from_env_optional(env_var: &str) -> Result<Option<String>, std::env::VarError> {
    match std::env::var(env_var) {
        Ok(value) => Ok(Some(value)),
        Err(VarError::NotPresent) => Ok(None),
        Err(other) => Err(other),
    }
}
182
183fn from_yaml(path: &str) -> serde_yaml::Result<Config> {
184 let reader = match File::open(path) {
185 Ok(r) => BufReader::new(r),
186 Err(e) => return Err(serde_yaml::Error::custom(e)),
187 };
188 serde_yaml::from_reader(reader)
189}
190
191fn from_json(path: &str) -> serde_json::Result<Config> {
192 let reader = match File::open(path) {
193 Ok(r) => BufReader::new(r),
194 Err(e) => return Err(serde_json::Error::custom(e)),
195 };
196 serde_json::from_reader(reader)
197}
198
/// Where the configuration is loaded from.
///
/// The file-based variants carry the path that was used to select them.
enum ConfigParsingStrategy<T> {
    /// Read every setting from environment variables.
    Env,
    /// Read a YAML file located at the contained path.
    Yaml(T),
    /// Read a JSON file located at the contained path.
    Json(T),
}

impl FromStr for ConfigParsingStrategy<String> {
    type Err = &'static str;

    /// Chooses a strategy from the value of the `CONFIG` setting.
    ///
    /// * `"ENV"` (exact match) selects [`ConfigParsingStrategy::Env`];
    /// * a value ending in `.yaml` or `.yml` selects `Yaml` with that path;
    /// * a value ending in `.json` selects `Json` with that path.
    ///
    /// # Errors
    ///
    /// Any other value yields `"invalid config type"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "ENV" {
            Ok(Self::Env)
        } else if s.ends_with(".yaml") || s.ends_with(".yml") {
            // `.yml` is the other common YAML extension; it used to be rejected.
            Ok(Self::Yaml(s.to_owned()))
        } else if s.ends_with(".json") {
            Ok(Self::Json(s.to_owned()))
        } else {
            Err("invalid config type")
        }
    }
}
217
218#[derive(Serialize, Deserialize, Clone, Debug)]
219pub struct Config {
220 pub addr: Option<SocketAddr>,
221 pub tls: Option<Tls>,
222 pub secure: Option<Secure>,
223 pub queue: Queue,
224 pub service_discovery: Option<ServiceDiscovery>,
225 #[serde(default)]
226 pub websocket: WebSocket,
227 #[serde(default)]
228 pub garbage_collector: GarbageCollector,
229}
230
231#[derive(Serialize, Deserialize, Clone, Debug)]
232pub struct Tls {
233 pub private_key: String,
234 pub cert: String,
235}
236
237#[derive(Serialize, Clone, Debug)]
238pub struct Secure {
239 pub service_token: SecureToken,
240 #[serde(default = "default_jwt_token_expiration")]
241 pub jwt_token_expiration: u64,
242}
243
244#[derive(Deserialize)]
245#[serde(remote = "Secure")]
246struct SecureDef {
247 pub service_token: SecureToken,
248 #[serde(default = "default_jwt_token_expiration")]
249 pub jwt_token_expiration: u64,
250}
251
/// Default JWT token expiration used when none is configured.
///
/// NOTE(review): the unit is not visible in this file — confirm against the
/// code that consumes `jwt_token_expiration`.
pub fn default_jwt_token_expiration() -> u64 {
    const DEFAULT_EXPIRATION: u64 = 60;
    DEFAULT_EXPIRATION
}
255
256pub type SecureToken = String;
257
258impl From<SecureToken> for Secure {
259 fn from(service_token: SecureToken) -> Self {
260 Self {
261 service_token,
262 jwt_token_expiration: default_jwt_token_expiration(),
263 }
264 }
265}
266
267#[derive(Serialize, Deserialize, Clone, Debug)]
268pub struct WebSocket {
269 pub key: String,
270 #[serde(default = "default_websocket_v")]
271 pub version: String,
272}
273
/// Default WebSocket version string (`"13"`).
fn default_websocket_v() -> String {
    String::from("13")
}
277
278impl Default for WebSocket {
279 fn default() -> Self {
280 Self {
281 key: "SGVsbG8sIHdvcmxkIQ==".into(),
282 version: default_websocket_v(),
283 }
284 }
285}
286
287#[derive(Serialize, Deserialize, Clone, Debug)]
288pub struct Queue {
289 #[serde(default)]
290 pub default: DefaultQueues,
291 pub db_path: Option<PathBuf>,
292 pub max_key_updates: Option<usize>,
293}
294
295pub type DefaultQueues = Vec<String>;
296
297#[derive(Serialize, Clone, Debug)]
298pub struct GarbageCollector {
299 pub interval: u64,
300}
301
302#[derive(Deserialize)]
303#[serde(remote = "GarbageCollector")]
304struct GarbageCollectorDef {
305 pub interval: u64,
306}
307
308impl From<u64> for GarbageCollector {
309 fn from(interval: u64) -> Self {
310 Self { interval }
311 }
312}
313
314impl Default for GarbageCollector {
315 fn default() -> Self {
316 Self::from(60)
317 }
318}
319
320pub type Shards = Vec<String>;
321
322#[derive(Serialize, Clone, Debug)]
323#[serde(tag = "type", rename_all = "lowercase")]
324pub enum ServiceDiscovery {
325 Api {
326 default: Option<Shards>,
327 },
328 Etcd {
329 default: Option<Shards>,
330 hosts: ServiceDiscoveryHosts,
331 #[serde(default = "default_sd_prefix")]
332 prefix: String,
333 instance_opts: Option<ServiceDiscoveryInstanceOptions>,
334 },
335}
336
/// Default service discovery prefix (`"sonya"`).
fn default_sd_prefix() -> String {
    String::from("sonya")
}
340
341#[derive(Deserialize)]
342#[serde(tag = "type", rename_all = "lowercase")]
343#[serde(remote = "ServiceDiscovery")]
344enum ServiceDiscoveryDef {
345 Api {
346 default: Option<Shards>,
347 },
348 Etcd {
349 default: Option<Shards>,
350 hosts: ServiceDiscoveryHosts,
351 #[serde(default = "default_sd_prefix")]
352 prefix: String,
353 instance_opts: Option<ServiceDiscoveryInstanceOptions>,
354 },
355}
356
357#[derive(Serialize, Deserialize, Clone, Debug)]
358pub struct ServiceDiscoveryInstanceOptions {
359 pub instance_id: Option<String>,
360 pub instance_addr: String,
361}
362
363impl Display for ServiceDiscovery {
364 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
365 write!(
366 f,
367 "{}",
368 match self {
369 ServiceDiscovery::Api { .. } => "api",
370 ServiceDiscovery::Etcd { .. } => "etcd",
371 }
372 )
373 }
374}
375
376impl From<Shards> for ServiceDiscovery {
377 fn from(default: Shards) -> Self {
378 Self::Api {
379 default: Some(default),
380 }
381 }
382}
383
/// List of service discovery hosts.
pub type ServiceDiscoveryHosts = Vec<String>;

/// Visitor marker: the target type is written as a string or as a map.
struct StringOrStruct<T>(PhantomData<T>);
/// Visitor marker: the target type is written as a string, a list of strings
/// or a map.
struct VecOrStruct<T>(PhantomData<T>);
/// Visitor marker: the target type is written as a `u64` or as a map.
struct U64OrStruct<T>(PhantomData<T>);
389
/// Implements `Deserialize` for `$struct_name` so that it can be written
/// either as a plain string or as a map.
///
/// * a string is converted with `From<String>`;
/// * a map is handed to `$struct_name_remote::deserialize`, where
///   `$struct_name_remote` is the `#[serde(remote = "...")]` mirror of the
///   target type.
///
/// The expansion refers to `Visitor`, `MapAccess`, `Deserialize`,
/// `Deserializer`, `de`, `fmt`, `PhantomData` and `StringOrStruct` by their
/// unqualified names, so those must be in scope at the call site.
#[macro_export]
macro_rules! string_or_struct_impl {
    ($struct_name: ident, $struct_name_remote: ident) => {
        impl<'de> Visitor<'de> for StringOrStruct<$struct_name> {
            type Value = $struct_name;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                // serde renders this text after "expected ", so it must be a
                // bare noun phrase; a trailing "expected" would read
                // "expected string or struct X expected".
                write!(
                    formatter,
                    "string or struct {}",
                    std::any::type_name::<Self::Value>()
                )
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Self::Value::from(value.to_owned()))
            }

            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
            where
                M: MapAccess<'de>,
            {
                // Delegate the map form to the remote-derived definition.
                $struct_name_remote::deserialize(de::value::MapAccessDeserializer::new(map))
            }
        }

        impl<'de> Deserialize<'de> for $struct_name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                deserializer.deserialize_any(StringOrStruct::<Self>(PhantomData))
            }
        }
    };
}
429
/// Implements `Deserialize` for `$struct_name` so that it can be written as a
/// single string, a list of strings or a map.
///
/// * a string becomes a one-element `Vec<String>` and is converted with
///   `From<Vec<String>>`;
/// * a sequence is collected into a `Vec<String>` and converted the same way;
/// * a map is handed to `$struct_name_remote::deserialize`, where
///   `$struct_name_remote` is the `#[serde(remote = "...")]` mirror of the
///   target type.
///
/// The expansion refers to `Visitor`, `SeqAccess`, `MapAccess`, `Deserialize`,
/// `Deserializer`, `de`, `fmt`, `PhantomData` and `VecOrStruct` by their
/// unqualified names, so those must be in scope at the call site.
#[macro_export]
macro_rules! vec_or_struct_impl {
    ($struct_name: ident, $struct_name_remote: ident) => {
        impl<'de> Visitor<'de> for VecOrStruct<$struct_name> {
            type Value = $struct_name;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                // serde renders this text after "expected ", so it must be a
                // bare noun phrase. It also names the single-string form that
                // `visit_str` accepts.
                write!(
                    formatter,
                    "string, list of strings or struct {}",
                    std::any::type_name::<Self::Value>()
                )
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Self::Value::from(vec![value.to_owned()]))
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut items = Vec::new();

                while let Some(item) = seq.next_element::<String>()? {
                    items.push(item);
                }

                Ok(Self::Value::from(items))
            }

            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
            where
                M: MapAccess<'de>,
            {
                // Delegate the map form to the remote-derived definition.
                $struct_name_remote::deserialize(de::value::MapAccessDeserializer::new(map))
            }
        }

        impl<'de> Deserialize<'de> for $struct_name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                deserializer.deserialize_any(VecOrStruct::<Self>(PhantomData))
            }
        }
    };
}
482
/// Implements `Deserialize` for `$struct_name` so that it can be written
/// either as an unsigned integer or as a map.
///
/// * a `u64` is converted with `From<u64>`;
/// * a map is handed to `$struct_name_remote::deserialize`, where
///   `$struct_name_remote` is the `#[serde(remote = "...")]` mirror of the
///   target type.
///
/// The expansion refers to `Visitor`, `MapAccess`, `Deserialize`,
/// `Deserializer`, `de`, `fmt`, `PhantomData` and `U64OrStruct` by their
/// unqualified names, so those must be in scope at the call site.
#[macro_export]
macro_rules! u64_or_struct {
    ($struct_name: ident, $struct_name_remote: ident) => {
        impl<'de> Visitor<'de> for U64OrStruct<$struct_name> {
            type Value = $struct_name;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                // This visitor accepts an integer, not a list of strings. The
                // text is rendered by serde after "expected ", so it is a bare
                // noun phrase.
                write!(
                    formatter,
                    "u64 or struct {}",
                    std::any::type_name::<Self::Value>()
                )
            }

            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(Self::Value::from(value))
            }

            fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
            where
                M: MapAccess<'de>,
            {
                // Delegate the map form to the remote-derived definition.
                $struct_name_remote::deserialize(de::value::MapAccessDeserializer::new(map))
            }
        }

        impl<'de> Deserialize<'de> for $struct_name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                deserializer.deserialize_any(U64OrStruct::<Self>(PhantomData))
            }
        }
    };
}
522
523string_or_struct_impl!(Secure, SecureDef);
524vec_or_struct_impl!(ServiceDiscovery, ServiceDiscoveryDef);
525u64_or_struct!(GarbageCollector, GarbageCollectorDef);