umbral_core/db/
route_context.rs1use std::future::Future;
6use std::sync::Arc;
7
/// Newtype over an owned `String` used as a tenant key.
///
/// The wrapped string is private; read it back through [`TenantKey::as_str`].
/// Equality and hashing are those of the underlying string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TenantKey(String);

impl TenantKey {
    /// Builds a key from anything convertible into a `String`.
    pub fn new(s: impl Into<String>) -> Self {
        let raw: String = s.into();
        Self(raw)
    }

    /// Borrows the key's text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
20
21#[derive(Clone, Default)]
24pub struct RouteContext {
25 tenant: Option<TenantKey>,
26 extensions: http::Extensions,
27}
28
29impl RouteContext {
30 pub fn new() -> Self {
31 Self::default()
32 }
33 pub fn with_tenant(mut self, tenant: TenantKey) -> Self {
34 self.tenant = Some(tenant);
35 self
36 }
37 pub fn tenant(&self) -> Option<&TenantKey> {
38 self.tenant.as_ref()
39 }
40 pub fn insert<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
42 self.extensions.insert(value);
43 }
44 pub fn get<T: Clone + Send + Sync + 'static>(&self) -> Option<&T> {
46 self.extensions.get::<T>()
47 }
48}
49
50tokio::task_local! {
51 static ROUTE_CONTEXT: Arc<RouteContext>;
52}
53
54pub fn current() -> Arc<RouteContext> {
59 ROUTE_CONTEXT
60 .try_with(|c| c.clone())
61 .unwrap_or_else(|_| Arc::new(RouteContext::default()))
62}
63
64pub async fn scope<F: Future>(ctx: RouteContext, fut: F) -> F::Output {
67 ROUTE_CONTEXT.scope(Arc::new(ctx), fut).await
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Outside any `scope` call, `current` yields a context without a tenant.
    #[tokio::test]
    async fn current_is_default_when_unset() {
        let ctx = current();
        assert!(ctx.tenant().is_none());
    }

    /// The tenant is visible inside `scope` and gone again once it completes.
    #[tokio::test]
    async fn scope_sets_and_restores_context() {
        let scoped = RouteContext::new().with_tenant(TenantKey::new("acme"));
        scope(scoped, async {
            let active = current();
            let tenant = active.tenant().unwrap();
            assert_eq!(tenant.as_str(), "acme");
        })
        .await;

        let after = current();
        assert!(after.tenant().is_none());
    }

    /// A task started with `tokio::spawn` inside `scope` sees no tenant.
    #[tokio::test]
    async fn spawned_task_does_not_inherit_context() {
        let scoped = RouteContext::new().with_tenant(TenantKey::new("acme"));
        scope(scoped, async {
            let spawned = tokio::spawn(async { current().tenant().cloned() });
            let seen = spawned.await.unwrap();
            assert!(seen.is_none());
        })
        .await;
    }

    /// A value inserted under its own type can be read back inside `scope`.
    #[tokio::test]
    async fn extensions_store_typed_values() {
        #[derive(Clone, PartialEq, Debug)]
        struct Region(&'static str);

        let mut ctx = RouteContext::new();
        ctx.insert(Region("eu"));
        scope(ctx, async {
            let active = current();
            assert_eq!(active.get::<Region>(), Some(&Region("eu")));
        })
        .await;
    }
}