Skip to main content

tork_core/
state.rs

1//! Type-erased application state and the [`State`] extractor.
2
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use tracing::warn;
8
9use crate::error::{Error, Result};
10use crate::extract::{FromRequest, RequestContext};
11
12/// A type-erased, thread-safe container for application state values.
13///
14/// Each value is stored under its [`TypeId`], so a state value is retrieved by
15/// its type. This lets routers and handlers stay free of any state type
16/// parameter: the [`App`](crate::App) is not generic over its state, which is
17/// what allows router modules to be built without knowing the concrete state
18/// type.
19#[derive(Default)]
20pub struct StateMap {
21    entries: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
22}
23
24impl StateMap {
25    /// Creates an empty state map.
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    /// Inserts a state value, replacing any existing value of the same type.
31    pub fn insert<S: Send + Sync + 'static>(&mut self, value: S) {
32        if self.entries.contains_key(&TypeId::of::<S>()) {
33            warn!(
34                target: "tork",
35                "state value of type `{}` is being silently replaced",
36                std::any::type_name::<S>(),
37            );
38        }
39        self.entries.insert(TypeId::of::<S>(), Arc::new(value));
40    }
41
42    /// Returns a shared handle to the stored value of type `S`, if present.
43    pub fn get<S: Send + Sync + 'static>(&self) -> Option<Arc<S>> {
44        self.entries
45            .get(&TypeId::of::<S>())
46            .and_then(|entry| entry.clone().downcast::<S>().ok())
47    }
48
49    /// Returns `true` if a value of type `S` is stored.
50    pub fn contains<S: Send + Sync + 'static>(&self) -> bool {
51        self.entries.contains_key(&TypeId::of::<S>())
52    }
53
54    /// Removes the stored value of type `S`, if present.
55    pub fn remove<S: Send + Sync + 'static>(&mut self) {
56        self.entries.remove(&TypeId::of::<S>());
57    }
58}
59
60/// A shared, reference-counted handle to the application state map.
61pub type AppStateRef = Arc<StateMap>;
62
63/// Extractor that yields a clone of an application state value of type `S`.
64///
65/// The wrapped value is cloned out of the shared state on each request, so `S`
66/// should be cheap to clone (for example, hold connection pools or other handles
67/// behind `Arc`).
68///
69/// # Errors
70///
71/// Resolving fails with an internal error if no value of type `S` was registered
72/// with [`App::state`](crate::App::state).
73pub struct State<S>(pub S);
74
75impl<S> FromRequest for State<S>
76where
77    S: Clone + Send + Sync + 'static,
78{
79    fn from_request(
80        ctx: &RequestContext,
81    ) -> impl std::future::Future<Output = Result<Self>> + Send {
82        let resolved = match ctx.state().get::<S>() {
83            Some(value) => Ok(State((*value).clone())),
84            None => Err(Error::internal(format!(
85                "application state `{}` was not configured",
86                std::any::type_name::<S>()
87            ))),
88        };
89        async move { resolved }
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone)]
    struct Config {
        name: String,
    }

    /// Builds a `Config` with the given name.
    fn config_named(name: &str) -> Config {
        Config {
            name: name.to_owned(),
        }
    }

    #[test]
    fn insert_and_get_by_type() {
        let mut map = StateMap::new();
        map.insert(config_named("tork"));

        let config = map.get::<Config>().expect("config should be present");
        assert_eq!(config.name, "tork");
        assert!(map.get::<u32>().is_none());
        assert!(map.contains::<Config>());
    }

    #[test]
    fn insert_replaces_value_of_same_type() {
        let mut map = StateMap::new();
        map.insert(config_named("first"));
        map.insert(config_named("second"));

        let config = map.get::<Config>().expect("config should be present");
        assert_eq!(config.name, "second");
    }

    #[test]
    fn distinct_types_are_stored_independently() {
        let mut map = StateMap::new();
        map.insert(1_u32);
        map.insert(2_u64);

        assert_eq!(map.get::<u32>().as_deref(), Some(&1));
        assert_eq!(map.get::<u64>().as_deref(), Some(&2));
    }

    #[test]
    fn remove_drops_only_the_requested_type() {
        let mut map = StateMap::new();
        map.insert(config_named("tork"));
        map.insert(42_u32);

        map.remove::<Config>();

        assert!(!map.contains::<Config>());
        assert!(map.get::<Config>().is_none());
        assert_eq!(map.get::<u32>().as_deref(), Some(&42));
    }

    #[test]
    fn handles_outlive_removal() {
        let mut map = StateMap::new();
        map.insert(config_named("tork"));
        let handle = map.get::<Config>().expect("config should be present");

        map.remove::<Config>();

        assert_eq!(handle.name, "tork");
    }

    #[test]
    fn remove_of_absent_type_is_a_no_op() {
        let mut map = StateMap::new();
        map.remove::<Config>();
        assert!(!map.contains::<Config>());
    }
}