1use conjure_object::Any;
19use pin_project::{pin_project, pinned_drop};
20use serde::Serialize;
21use std::cell::RefCell;
22use std::collections::{hash_map, HashMap};
23use std::future::Future;
24use std::mem;
25use std::pin::Pin;
26use std::sync::{Arc, OnceLock};
27use std::task::{Context, Poll};
28
29static EMPTY: OnceLock<Map> = OnceLock::new();
30
31thread_local! {
32 static MDC: RefCell<Snapshot> = RefCell::new(Snapshot::new());
33}
34
35pub fn insert_safe<T>(key: &'static str, value: T) -> Option<Any>
41where
42 T: Serialize,
43{
44 MDC.with(|v| v.borrow_mut().safe_mut().insert(key, value))
45}
46
47pub fn insert_unsafe<T>(key: &'static str, value: T) -> Option<Any>
53where
54 T: Serialize,
55{
56 MDC.with(|v| v.borrow_mut().unsafe_mut().insert(key, value))
57}
58
59pub fn remove_safe(key: &str) -> Option<Any> {
61 MDC.with(|v| v.borrow_mut().safe_mut().remove(key))
62}
63
64pub fn remove_unsafe(key: &str) -> Option<Any> {
66 MDC.with(|v| v.borrow_mut().unsafe_mut().remove(key))
67}
68
69pub fn snapshot() -> Snapshot {
73 MDC.with(|v| v.borrow().clone())
74}
75
76pub fn clear() {
78 MDC.with(|v| {
79 let mut mdc = v.borrow_mut();
80 mdc.safe_mut().clear();
81 mdc.unsafe_mut().clear();
82 });
83}
84
85pub fn set(snapshot: Snapshot) -> Snapshot {
87 MDC.with(|v| mem::replace(&mut *v.borrow_mut(), snapshot))
88}
89
90pub fn swap(snapshot: &mut Snapshot) {
92 MDC.with(|v| mem::swap(&mut *v.borrow_mut(), snapshot));
93}
94
95pub fn bind<F>(future: F) -> Bind<F> {
100 Bind {
101 future: Some(future),
102 snapshot: snapshot(),
103 }
104}
105
106pub fn scope() -> Scope {
108 Scope { old: snapshot() }
109}
110
111#[derive(Clone, Debug, PartialEq, Eq)]
113pub struct Map {
114 map: Arc<HashMap<&'static str, Any>>,
115}
116
117impl Default for Map {
118 #[inline]
119 fn default() -> Self {
120 EMPTY
121 .get_or_init(|| Map {
122 map: Arc::new(HashMap::new()),
123 })
124 .clone()
125 }
126}
127
128impl Map {
129 #[inline]
131 pub fn new() -> Self {
132 Map::default()
133 }
134
135 #[inline]
137 pub fn clear(&mut self) {
138 match Arc::get_mut(&mut self.map) {
140 Some(map) => map.clear(),
141 None => *self = Map::new(),
142 }
143 }
144
145 #[inline]
147 pub fn len(&self) -> usize {
148 self.map.len()
149 }
150
151 #[inline]
153 pub fn is_empty(&self) -> bool {
154 self.map.is_empty()
155 }
156
157 #[inline]
159 pub fn get(&self, key: &str) -> Option<&Any> {
160 self.map.get(key)
161 }
162
163 #[inline]
165 pub fn contains_key(&self, key: &str) -> bool {
166 self.map.contains_key(key)
167 }
168
169 #[inline]
175 pub fn insert<V>(&mut self, key: &'static str, value: V) -> Option<Any>
176 where
177 V: Serialize,
178 {
179 let value = Any::new(value).expect("value failed to serialize");
180 Arc::make_mut(&mut self.map).insert(key, value)
181 }
182
183 #[inline]
185 pub fn remove(&mut self, key: &str) -> Option<Any> {
186 Arc::make_mut(&mut self.map).remove(key)
187 }
188
189 #[inline]
191 pub fn iter(&self) -> Iter<'_> {
192 Iter {
193 it: self.map.iter(),
194 }
195 }
196}
197
198impl<'a> IntoIterator for &'a Map {
199 type Item = (&'static str, &'a Any);
200
201 type IntoIter = Iter<'a>;
202
203 #[inline]
204 fn into_iter(self) -> Self::IntoIter {
205 self.iter()
206 }
207}
208
209pub struct Iter<'a> {
211 it: hash_map::Iter<'a, &'static str, Any>,
212}
213
214impl<'a> Iterator for Iter<'a> {
215 type Item = (&'static str, &'a Any);
216
217 #[inline]
218 fn next(&mut self) -> Option<Self::Item> {
219 self.it.next().map(|(k, v)| (*k, v))
220 }
221
222 #[inline]
223 fn size_hint(&self) -> (usize, Option<usize>) {
224 self.it.size_hint()
225 }
226}
227
228impl ExactSizeIterator for Iter<'_> {
229 #[inline]
230 fn len(&self) -> usize {
231 self.it.len()
232 }
233}
234
235#[derive(Clone, Default, Debug, PartialEq, Eq)]
237pub struct Snapshot {
238 safe: Map,
239 unsafe_: Map,
240}
241
242impl Snapshot {
243 #[inline]
245 pub fn new() -> Self {
246 Snapshot::default()
247 }
248
249 #[inline]
251 pub fn safe(&self) -> &Map {
252 &self.safe
253 }
254
255 #[inline]
257 pub fn safe_mut(&mut self) -> &mut Map {
258 &mut self.safe
259 }
260
261 #[inline]
263 pub fn unsafe_(&self) -> &Map {
264 &self.unsafe_
265 }
266
267 #[inline]
269 pub fn unsafe_mut(&mut self) -> &mut Map {
270 &mut self.unsafe_
271 }
272}
273
274pub struct Scope {
276 old: Snapshot,
277}
278
279impl Drop for Scope {
280 fn drop(&mut self) {
281 swap(&mut self.old);
282 }
283}
284
285#[pin_project(PinnedDrop)]
287pub struct Bind<F> {
288 #[pin]
289 future: Option<F>,
290 snapshot: Snapshot,
291}
292
293#[pinned_drop]
294impl<F> PinnedDrop for Bind<F> {
295 fn drop(self: Pin<&mut Self>) {
296 let mut this = self.project();
297 let _guard = scope_with(this.snapshot);
298 this.future.set(None);
299 }
300}
301
302impl<F> Future for Bind<F>
303where
304 F: Future,
305{
306 type Output = F::Output;
307
308 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309 let this = self.project();
310 let _guard = scope_with(this.snapshot);
311 this.future.as_pin_mut().unwrap().poll(cx)
312 }
313}
314
315fn scope_with(snapshot: &mut Snapshot) -> ScopeWith<'_> {
317 swap(snapshot);
318 ScopeWith { snapshot }
319}
320
321struct ScopeWith<'a> {
322 snapshot: &'a mut Snapshot,
323}
324
325impl Drop for ScopeWith<'_> {
326 fn drop(&mut self) {
327 swap(self.snapshot);
328 }
329}
330
#[cfg(test)]
mod test {
    use conjure_object::Any;

    use crate::mdc;

    #[test]
    fn scope() {
        mdc::clear();

        mdc::insert_safe("foo", "bar");
        let guard = mdc::scope();
        mdc::insert_safe("foo", "baz");
        assert_eq!(
            mdc::snapshot().safe().get("foo").unwrap(),
            &Any::new("baz").unwrap(),
        );

        drop(guard);
        assert_eq!(
            mdc::snapshot().safe().get("foo").unwrap(),
            &Any::new("bar").unwrap(),
        );
    }

    #[test]
    fn bind() {
        mdc::clear();

        mdc::insert_safe("foo", "bar");
        futures_executor::block_on(mdc::bind(async {
            mdc::insert_safe("foo", "baz");
            assert_eq!(
                mdc::snapshot().safe().get("foo").unwrap(),
                &Any::new("baz").unwrap(),
            );
        }));

        assert_eq!(
            mdc::snapshot().safe().get("foo").unwrap(),
            &Any::new("bar").unwrap(),
        );
    }

    /// Dropping a `Bind` must drop the wrapped future inside the bound MDC, then restore the
    /// caller's MDC.
    #[test]
    fn bind_drops_future_in_bound_context() {
        use std::cell::Cell;
        use std::rc::Rc;

        // Records whether the MDC visible from its destructor holds the bound value.
        struct CheckOnDrop(Rc<Cell<bool>>);

        impl Drop for CheckOnDrop {
            fn drop(&mut self) {
                let value = mdc::snapshot().safe().get("foo").cloned();
                self.0.set(value == Some(Any::new("baz").unwrap()));
            }
        }

        mdc::clear();

        mdc::insert_safe("foo", "baz");
        let observed = Rc::new(Cell::new(false));
        let future = mdc::bind(std::future::ready(CheckOnDrop(Rc::clone(&observed))));
        mdc::insert_safe("foo", "bar");

        drop(future);

        assert!(observed.get());
        assert_eq!(
            mdc::snapshot().safe().get("foo").unwrap(),
            &Any::new("bar").unwrap(),
        );
    }

    /// Removing from a map whose entries are shared must not disturb the other holder.
    #[test]
    fn remove_from_shared_map() {
        let mut map = mdc::Map::new();
        map.insert("foo", "bar");
        let shared = map.clone();

        assert_eq!(map.remove("missing"), None);
        assert_eq!(map, shared);

        assert_eq!(map.remove("foo"), Some(Any::new("bar").unwrap()));
        assert!(map.is_empty());
        assert_eq!(shared.get("foo"), Some(&Any::new("bar").unwrap()));
    }
}