// rustapi_testing/server.rs
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, PoisonError};

use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio::sync::oneshot;

use super::expectation::{Expectation, MockResponse, Times};
use super::matcher::RequestMatcher;
12
/// Boxed, thread-safe error type returned by the request handler.
type GenericError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Shorthand for results whose error side is [`GenericError`].
type Result<T> = core::result::Result<T, GenericError>;
15
16pub struct MockServer {
18 addr: SocketAddr,
19 state: Arc<Mutex<ServerState>>,
20 shutdown_tx: Option<oneshot::Sender<()>>,
21}
22
23struct ServerState {
24 expectations: Vec<Expectation>,
25 unmatched_requests: Vec<RecordedRequest>,
26}
27
28#[derive(Debug, Clone)]
29pub struct RecordedRequest {
30 pub method: http::Method,
31 pub path: String,
32 pub headers: http::HeaderMap,
33 pub body: Bytes,
34}
35
36impl MockServer {
37 pub async fn start() -> Self {
39 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
40 let addr = listener.local_addr().unwrap();
41
42 let state = Arc::new(Mutex::new(ServerState {
43 expectations: Vec::new(),
44 unmatched_requests: Vec::new(),
45 }));
46
47 let state_clone = state.clone();
48 let (shutdown_tx, shutdown_rx) = oneshot::channel();
49
50 tokio::spawn(async move {
51 let mut stop_future = shutdown_rx;
52
53 loop {
54 tokio::select! {
55 res = listener.accept() => {
56 match res {
57 Ok((stream, _)) => {
58 let io = TokioIo::new(stream);
59 let state = state_clone.clone();
60
61 tokio::spawn(async move {
62 if let Err(err) = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
63 .serve_connection(io, service_fn(move |req| handle_request(req, state.clone())))
64 .await
65 {
66 eprintln!("Error serving connection: {:?}", err);
67 }
68 });
69 }
70 Err(e) => eprintln!("Accept error: {}", e),
71 }
72 }
73 _ = &mut stop_future => {
74 break;
75 }
76 }
77 }
78 });
79
80 Self {
81 addr,
82 state,
83 shutdown_tx: Some(shutdown_tx),
84 }
85 }
86
87 pub fn kind_url(&self) -> String {
89 format!("http://{}", self.addr)
90 }
91
92 pub fn base_url(&self) -> String {
94 self.kind_url()
95 }
96
97 pub fn unmatched_requests(&self) -> Vec<RecordedRequest> {
99 let state = self.state.lock().unwrap();
100 state.unmatched_requests.clone()
101 }
102
103 pub fn expect(&self, matcher: RequestMatcher) -> ExpectationBuilder {
105 ExpectationBuilder {
106 server: self.state.clone(),
107 expectation: Some(Expectation::new(matcher)),
108 }
109 }
110
111 pub fn verify(&self) {
113 let state = self.state.lock().unwrap();
114 for exp in &state.expectations {
115 match exp.times {
116 Times::Once => assert_eq!(
117 exp.call_count, 1,
118 "Expectation {:?} expected 1 call, got {}",
119 exp.matcher, exp.call_count
120 ),
121 Times::Exactly(n) => assert_eq!(
122 exp.call_count, n,
123 "Expectation {:?} expected {} calls, got {}",
124 exp.matcher, n, exp.call_count
125 ),
126 Times::AtLeast(n) => assert!(
127 exp.call_count >= n,
128 "Expectation {:?} expected at least {} calls, got {}",
129 exp.matcher,
130 n,
131 exp.call_count
132 ),
133 Times::AtMost(n) => assert!(
134 exp.call_count <= n,
135 "Expectation {:?} expected at most {} calls, got {}",
136 exp.matcher,
137 n,
138 exp.call_count
139 ),
140 Times::Any => {}
141 }
142 }
143 }
144}
145
146impl Drop for MockServer {
147 fn drop(&mut self) {
148 if let Some(tx) = self.shutdown_tx.take() {
149 let _ = tx.send(());
150 }
151 }
152}
153
154pub struct ExpectationBuilder {
155 server: Arc<Mutex<ServerState>>,
156 expectation: Option<Expectation>,
157}
158
159impl ExpectationBuilder {
160 pub fn respond_with(mut self, response: MockResponse) -> Self {
161 if let Some(exp) = self.expectation.as_mut() {
162 exp.response = response;
163 }
164 self
165 }
166
167 pub fn times(mut self, n: usize) -> Self {
168 if let Some(exp) = self.expectation.as_mut() {
169 exp.times = Times::Exactly(n);
170 }
171 self
172 }
173
174 pub fn once(mut self) -> Self {
175 if let Some(exp) = self.expectation.as_mut() {
176 exp.times = Times::Once;
177 }
178 self
179 }
180
181 pub fn at_least_once(mut self) -> Self {
182 if let Some(exp) = self.expectation.as_mut() {
183 exp.times = Times::AtLeast(1);
184 }
185 self
186 }
187
188 pub fn never(mut self) -> Self {
189 if let Some(exp) = self.expectation.as_mut() {
190 exp.times = Times::Exactly(0);
191 }
192 self
193 }
194}
195
196impl Drop for ExpectationBuilder {
197 fn drop(&mut self) {
198 if let Some(exp) = self.expectation.take() {
199 let mut state = self.server.lock().unwrap();
200 state.expectations.push(exp);
201 }
202 }
203}
204
205async fn handle_request(
206 req: Request<hyper::body::Incoming>,
207 state: Arc<Mutex<ServerState>>,
208) -> Result<Response<Full<Bytes>>> {
209 let (parts, body) = req.into_parts();
211 let body_bytes = body.collect().await?.to_bytes();
212
213 let mut state_guard = state.lock().unwrap();
214
215 let matching_idx = state_guard
218 .expectations
219 .iter()
220 .enumerate()
221 .rev()
222 .find(|(_, exp)| {
223 exp.matcher
224 .matches(&parts.method, parts.uri.path(), &parts.headers, &body_bytes)
225 })
226 .map(|(i, _)| i);
227
228 if let Some(idx) = matching_idx {
229 let exp = &mut state_guard.expectations[idx];
230 exp.call_count += 1;
231
232 let resp_def = &exp.response;
233 let mut response = Response::builder().status(resp_def.status);
234
235 for (k, v) in &resp_def.headers {
236 response = response.header(k, v);
237 }
238
239 Ok(response.body(Full::new(resp_def.body.clone()))?)
240 } else {
241 state_guard.unmatched_requests.push(RecordedRequest {
243 method: parts.method,
244 path: parts.uri.path().to_string(),
245 headers: parts.headers,
246 body: body_bytes,
247 });
248
249 Ok(Response::builder()
250 .status(StatusCode::NOT_FOUND)
251 .body(Full::new(Bytes::from("No expectation matched")))?)
252 }
253}