Skip to main content

rustio_core/
context.rs

//! Typed per-request storage keyed by `TypeId`.
//!
//! Middleware writes into the context (`req.ctx_mut().insert(User { .. })`);
//! handlers and later middleware read by type (`req.ctx().get::<User>()`).
//! Stored values only need to be `Send + Sync + 'static` — no `Clone` bound.
6
7use std::any::{Any, TypeId};
8use std::collections::HashMap;
9use std::fmt;
10
/// Per-request storage that holds at most one value of each concrete type.
///
/// The value's type is its key: storing a second `T` replaces the first, and
/// lookups name the type they want (`ctx.get::<T>()`). Entries are kept as
/// `Box<dyn Any + Send + Sync>`, so stored values need neither `Clone` nor
/// `Debug`, and the context itself stays `Send + Sync`.
pub struct Context {
    // Invariant: the box under key `k` always holds a value whose `TypeId` is
    // `k` (only `insert` writes here, keying by `TypeId::of::<T>()`).
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl Context {
    /// Returns a context with no entries.
    pub fn new() -> Self {
        let map = HashMap::new();
        Context { map }
    }

    /// Stores `value` under its type, `T`.
    ///
    /// Returns the `T` that was stored before this call, or `None` if the
    /// context held no `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) -> Option<T> {
        let key = TypeId::of::<T>();
        let displaced = self.map.insert(key, Box::new(value))?;
        Self::unbox::<T>(displaced)
    }

    /// Borrows the stored `T`, if there is one.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let slot = self.map.get(&TypeId::of::<T>())?;
        slot.downcast_ref::<T>()
    }

    /// Mutably borrows the stored `T`, if there is one, so it can be updated in place.
    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
        let slot = self.map.get_mut(&TypeId::of::<T>())?;
        slot.downcast_mut::<T>()
    }

    /// Takes the stored `T` out of the context, returning it by value.
    ///
    /// Returns `None` (and leaves the context unchanged) when no `T` is stored.
    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
        let taken = self.map.remove(&TypeId::of::<T>())?;
        Self::unbox::<T>(taken)
    }

    /// Reports whether a `T` is currently stored.
    ///
    /// Only `'static` is required here, so any type may be queried; types that
    /// are not `Send + Sync` can never have been inserted and always yield `false`.
    pub fn contains<T: 'static>(&self) -> bool {
        let key = TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Number of stored entries (one per distinct type).
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// `true` when nothing is stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Turns a type-erased box back into an owned `T`.
    ///
    /// Yields `None` if the box holds some other type, which the map's keying
    /// invariant rules out for boxes taken from `self.map` under `TypeId::of::<T>()`.
    fn unbox<T: 'static>(boxed: Box<dyn Any + Send + Sync>) -> Option<T> {
        match boxed.downcast::<T>() {
            Ok(inner) => Some(*inner),
            Err(_) => None,
        }
    }
}

impl Default for Context {
    /// Same as [`Context::new`]: an empty context.
    fn default() -> Self {
        Context::new()
    }
}

impl fmt::Debug for Context {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stored values are type-erased and not required to be `Debug`, so only
        // the number of entries is shown.
        let entries = self.map.len();
        f.debug_struct("Context").field("entries", &entries).finish()
    }
}
72
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq)]
    struct Tag(u32);

    // Deliberately not `Clone`: proves the context imposes no `Clone` bound.
    struct NotClone(String);

    #[test]
    fn insert_and_get_by_type() {
        let mut ctx = Context::new();
        assert!(ctx.insert(Tag(42)).is_none());
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(42)));
    }

    #[test]
    fn different_types_coexist() {
        let mut ctx = Context::new();
        ctx.insert(Tag(1));
        ctx.insert(String::from("hello"));
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(1)));
        assert_eq!(ctx.get::<String>().map(String::as_str), Some("hello"));
    }

    #[test]
    fn insert_replaces_same_type_and_returns_prev() {
        let mut ctx = Context::new();
        ctx.insert(Tag(1));
        let prev = ctx.insert(Tag(2));
        assert_eq!(prev, Some(Tag(1)));
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(2)));
    }

    #[test]
    fn get_mut_permits_mutation_in_place() {
        let mut ctx = Context::new();
        ctx.insert(Tag(0));
        ctx.get_mut::<Tag>().unwrap().0 = 7;
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(7)));
    }

    #[test]
    fn remove_returns_owned_value_and_clears() {
        let mut ctx = Context::new();
        ctx.insert(Tag(9));
        assert_eq!(ctx.remove::<Tag>(), Some(Tag(9)));
        assert!(!ctx.contains::<Tag>());
        assert!(ctx.get::<Tag>().is_none());
    }

    #[test]
    fn absent_type_returns_none() {
        let ctx = Context::new();
        assert!(ctx.get::<Tag>().is_none());
        assert!(!ctx.contains::<Tag>());
    }

    #[test]
    fn non_clone_types_are_allowed() {
        let mut ctx = Context::new();
        ctx.insert(NotClone("present".into()));
        assert_eq!(ctx.get::<NotClone>().unwrap().0, "present");
    }

    #[test]
    fn len_and_is_empty_count_distinct_types() {
        let mut ctx = Context::new();
        assert!(ctx.is_empty());
        assert_eq!(ctx.len(), 0);

        ctx.insert(Tag(1));
        // Same type again: replaces the existing entry instead of adding one.
        ctx.insert(Tag(2));
        ctx.insert(NotClone("x".into()));
        assert_eq!(ctx.len(), 2);
        assert!(!ctx.is_empty());

        ctx.remove::<Tag>();
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn newtype_is_keyed_separately_from_inner_type() {
        let mut ctx = Context::new();
        ctx.insert(7u32);
        ctx.insert(Tag(8));
        assert_eq!(ctx.get::<u32>(), Some(&7));
        assert_eq!(ctx.get::<Tag>(), Some(&Tag(8)));
        assert_eq!(ctx.len(), 2);
    }

    #[test]
    fn absent_type_get_mut_and_remove_return_none() {
        let mut ctx = Context::new();
        assert!(ctx.get_mut::<Tag>().is_none());
        assert!(ctx.remove::<Tag>().is_none());
        assert!(ctx.is_empty());
    }

    #[test]
    fn default_is_empty() {
        let ctx = Context::default();
        assert!(ctx.is_empty());
        assert_eq!(ctx.len(), 0);
    }

    #[test]
    fn debug_reports_entry_count() {
        let mut ctx = Context::new();
        assert_eq!(format!("{:?}", ctx), "Context { entries: 0 }");
        ctx.insert(Tag(5));
        assert_eq!(format!("{:?}", ctx), "Context { entries: 1 }");
    }
}