1use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use tracing::warn;
8
9use crate::error::{Error, Result};
10use crate::extract::{FromRequest, RequestContext};
11
/// Type-keyed container for application state.
///
/// Each value is stored behind an `Arc` and indexed by the `TypeId` of its
/// concrete type, so the map holds at most one value per type.
#[derive(Default)]
pub struct StateMap {
    /// One type-erased value per concrete type, keyed by that type's `TypeId`.
    entries: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

// `dyn Any` has no `Debug` impl, so `#[derive(Debug)]` is not possible; report
// how many values are stored instead of their (opaque) contents.
impl std::fmt::Debug for StateMap {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateMap")
            .field("len", &self.entries.len())
            .finish()
    }
}
23
24impl StateMap {
25 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn insert<S: Send + Sync + 'static>(&mut self, value: S) {
32 if self.entries.contains_key(&TypeId::of::<S>()) {
33 warn!(
34 target: "tork",
35 "state value of type `{}` is being silently replaced",
36 std::any::type_name::<S>(),
37 );
38 }
39 self.entries.insert(TypeId::of::<S>(), Arc::new(value));
40 }
41
42 pub fn get<S: Send + Sync + 'static>(&self) -> Option<Arc<S>> {
44 self.entries
45 .get(&TypeId::of::<S>())
46 .and_then(|entry| entry.clone().downcast::<S>().ok())
47 }
48
49 pub fn contains<S: Send + Sync + 'static>(&self) -> bool {
51 self.entries.contains_key(&TypeId::of::<S>())
52 }
53
54 pub fn remove<S: Send + Sync + 'static>(&mut self) {
56 self.entries.remove(&TypeId::of::<S>());
57 }
58}
59
60pub type AppStateRef = Arc<StateMap>;
62
/// Extractor wrapping a clone of the application state value of type `S`.
///
/// Extraction looks `S` up by type in the request's state and clones it; the
/// request fails with an internal error when no value of that type was
/// registered. The derives below apply only when `S` itself implements the
/// corresponding trait.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct State<S>(pub S);
74
75impl<S> FromRequest for State<S>
76where
77 S: Clone + Send + Sync + 'static,
78{
79 fn from_request(
80 ctx: &RequestContext,
81 ) -> impl std::future::Future<Output = Result<Self>> + Send {
82 let resolved = match ctx.state().get::<S>() {
83 Some(value) => Ok(State((*value).clone())),
84 None => Err(Error::internal(format!(
85 "application state `{}` was not configured",
86 std::any::type_name::<S>()
87 ))),
88 };
89 async move { resolved }
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct Config {
        name: String,
    }

    /// Builds a `Config` with the given name.
    fn config(name: &str) -> Config {
        Config {
            name: name.to_owned(),
        }
    }

    #[test]
    fn insert_and_get_by_type() {
        let mut map = StateMap::new();
        map.insert(config("tork"));

        let config = map.get::<Config>().expect("config should be present");
        assert_eq!(config.name, "tork");
        assert!(map.get::<u32>().is_none());
        assert!(map.contains::<Config>());
    }

    #[test]
    fn insert_replaces_existing_value_of_same_type() {
        let mut map = StateMap::new();
        map.insert(config("first"));
        map.insert(config("second"));

        let config = map.get::<Config>().expect("config should be present");
        assert_eq!(config.name, "second");
    }

    #[test]
    fn values_of_distinct_types_coexist() {
        let mut map = StateMap::new();
        map.insert(config("tork"));
        map.insert(7u32);

        assert_eq!(*map.get::<u32>().expect("u32 should be present"), 7);
        assert_eq!(
            map.get::<Config>().expect("config should be present").name,
            "tork"
        );
    }

    #[test]
    fn remove_drops_only_the_requested_type() {
        let mut map = StateMap::new();
        map.insert(config("tork"));
        map.insert(7u32);

        map.remove::<Config>();
        assert!(!map.contains::<Config>());
        assert!(map.get::<Config>().is_none());
        assert!(map.contains::<u32>());

        // Removing a type that is no longer present must be a harmless no-op.
        map.remove::<Config>();
        assert!(map.contains::<u32>());
    }
}