1use std::any::{Any, TypeId};
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use crate::extract::RequestContext;
15
16mod client;
17mod cookie;
18mod recorder;
19mod request;
20mod response;
21mod sse;
22mod websocket;
23
24pub use client::{TestClient, TestClientBuilder};
25pub use recorder::{LogRecord, LogRecorder};
26pub use request::{TestMultipartBuilder, TestRequestBuilder};
27pub use response::TestResponse;
28pub use sse::{TestSseEvent, TestSseStream};
29pub use websocket::{TestWebSocket, TestWebSocketBuilder};
30
/// Type-erased override factory: a shareable, thread-safe closure that returns a
/// newly boxed `dyn Any + Send` value each time it is called.
type OverrideFactory = Arc<dyn Fn() -> Box<dyn Any + Send> + Send + Sync>;
33
/// Registry of override factories, keyed by the `TypeId` of the value each
/// factory produces.
///
/// `Default` yields an empty registry. `Clone` copies the map while sharing the
/// underlying factories, since each one is held behind an `Arc`.
#[derive(Default, Clone)]
pub(crate) struct TestOverrides {
    /// One factory per overridden type; the key is `TypeId::of::<T>()` for the
    /// `T` the factory returns.
    factories: HashMap<TypeId, OverrideFactory>,
}
43
44impl TestOverrides {
45 #[allow(dead_code)]
48 pub(crate) fn insert<T, F>(&mut self, factory: F)
49 where
50 T: Send + 'static,
51 F: Fn() -> T + Send + Sync + 'static,
52 {
53 self.factories
54 .insert(TypeId::of::<T>(), Arc::new(move || Box::new(factory())));
55 }
56
57 fn produce<T: 'static>(&self) -> Option<T> {
59 let factory = self.factories.get(&TypeId::of::<T>())?;
60 factory().downcast::<T>().ok().map(|boxed| *boxed)
61 }
62
63 #[allow(dead_code)]
65 pub(crate) fn is_empty(&self) -> bool {
66 self.factories.is_empty()
67 }
68}
69
70#[doc(hidden)]
76pub fn __take_override<T: 'static>(ctx: &RequestContext) -> Option<T> {
77 ctx.state().get::<TestOverrides>()?.produce::<T>()
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::extract::PathParams;
    use crate::state::StateMap;
    use bytes::Bytes;
    use http_body_util::Full;

    /// Builds a `RequestContext` with a default request head, no path
    /// parameters, an empty body, and a state map holding only `overrides`.
    fn context_with_overrides(overrides: TestOverrides) -> RequestContext {
        let (head, ()) = http::Request::new(()).into_parts();
        let empty_body = box_body(Full::new(Bytes::new()));

        let mut state = StateMap::new();
        state.insert(overrides);

        RequestContext::new(head, PathParams::new(), Arc::new(state), empty_body)
    }

    #[test]
    fn override_registry_reports_empty_and_produces_fresh_values() {
        let mut registry = TestOverrides::default();
        assert!(registry.is_empty());

        registry.insert::<String, _>(|| "hello".to_owned());
        assert!(!registry.is_empty());

        let produced = registry.produce::<String>();
        assert_eq!(produced.as_deref(), Some("hello"));
    }

    #[test]
    fn take_override_reads_registered_override() {
        let mut registry = TestOverrides::default();
        registry.insert::<usize, _>(|| 7usize);
        let ctx = context_with_overrides(registry);

        // A registered type resolves to the factory's value...
        let registered = __take_override::<usize>(&ctx);
        assert_eq!(registered, Some(7));

        // ...while a type with no factory yields nothing.
        let unregistered = __take_override::<String>(&ctx);
        assert_eq!(unregistered, None);
    }
}