rustbasic_core/
testing.rs1use crate::app::Config;
2use crate::server::AppState;
3use crate::router::{Router, Response};
4use crate::requests::Request;
5use crate::session::Session;
6use crate::session_manager::RustBasicSessionStore;
7use std::sync::Arc;
8use crate::rand::distr::SampleString;
9
/// In-process client for exercising a [`Router`] from tests.
///
/// Requests are built as [`Request`] values and driven directly through the
/// middleware/handler chain; no socket is opened.
#[derive(Clone)]
pub struct TestClient {
    /// Application state (database handle and shared config) given to every request.
    pub state: AppState,
    /// Routes and router-level middleware that requests are dispatched to.
    pub router: Router<AppState>,
    /// Session store returned by `crate::session::setup_session` when the client was built.
    pub session_store: RustBasicSessionStore,
}
16
17impl TestClient {
18 pub async fn new(cfg: Config, router: Router<AppState>) -> Self {
19 let db = crate::database::connect(&cfg).await;
21
22 crate::session::init_sessions(&cfg).await;
24 let session_store = crate::session::setup_session(&cfg).await;
25
26 Self {
27 state: AppState {
28 db,
29 config: Arc::new(cfg),
30 },
31 router,
32 session_store,
33 }
34 }
35
36 pub async fn get(&self, path: &str) -> TestResponse {
37 self.send_request("GET", path, None, None).await
38 }
39
40 pub async fn post(&self, path: &str, body: serde_json::Value) -> TestResponse {
41 self.send_request("POST", path, Some(body), None).await
42 }
43
44 pub async fn put(&self, path: &str, body: serde_json::Value) -> TestResponse {
45 self.send_request("PUT", path, Some(body), None).await
46 }
47
48 pub async fn patch(&self, path: &str, body: serde_json::Value) -> TestResponse {
49 self.send_request("PATCH", path, Some(body), None).await
50 }
51
52 pub async fn delete(&self, path: &str) -> TestResponse {
53 self.send_request("DELETE", path, None, None).await
54 }
55
56 pub async fn send_request(
57 &self,
58 method_str: &str,
59 path: &str,
60 body_json: Option<serde_json::Value>,
61 headers_opt: Option<std::collections::HashMap<String, String>>,
62 ) -> TestResponse {
63 let method = http::Method::from_bytes(method_str.as_bytes()).unwrap();
64 let inputs = body_json.unwrap_or_else(|| serde_json::json!({}));
65 let mut headers = headers_opt.unwrap_or_default();
66
67 let id = crate::rand::distr::Alphanumeric.sample_string(&mut crate::rand::rng(), 40);
68 let session = Session::new(id.clone());
69
70 let token = crate::rand::distr::Alphanumeric.sample_string(&mut crate::rand::rng(), 40);
72 session.set("_token", token.clone());
73
74 if !headers.contains_key("x-csrf-token") {
76 headers.insert("x-csrf-token".to_string(), token);
77 }
78
79 let req = Request {
80 inputs,
81 method: method.clone(),
82 path: path.to_string(),
83 headers,
84 session: session.clone(),
85 state: self.state.clone(),
86 ip_address: "127.0.0.1".to_string(),
87 params: std::collections::HashMap::new(),
88 };
89
90 struct RouteDispatcher {
91 router: Router<AppState>,
92 }
93
94 #[crate::async_trait]
95 impl crate::router::ErasedHandler for RouteDispatcher {
96 async fn call(&self, req: Request) -> Response {
97 let method = req.method.clone();
98 let path = req.path.clone();
99
100 let mut matched_handler = None;
101 let mut matched_params = std::collections::HashMap::new();
102 for route in &self.router.routes {
103 if crate::server::match_path(&route.path, &path) {
104 for (m, h) in &route.handlers {
105 if m == &method {
106 matched_handler = Some(h.clone());
107 matched_params = crate::server::extract_params(&route.path, &path);
108 break;
109 }
110 }
111 }
112 if matched_handler.is_some() {
113 break;
114 }
115 }
116
117 if let Some(handler) = matched_handler {
118 let mut req = req;
119 req.params = matched_params;
120 let mut chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::End(handler));
121 for mw in self.router.middlewares.iter().rev() {
122 chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::Next(mw.clone(), chain));
123 }
124 chain.next(req).await
125 } else {
126 crate::errors::ErrorController::not_found().await
127 }
128 }
129 }
130
131 let dispatcher = std::sync::Arc::new(RouteDispatcher {
132 router: self.router.clone(),
133 });
134
135 let mut chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::End(dispatcher));
136 chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::Next(
137 crate::middleware::from_fn(crate::middleware::security_headers::security_headers_middleware),
138 chain,
139 ));
140 chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::Next(
141 crate::middleware::from_fn(crate::middleware::logging::logging_middleware),
142 chain,
143 ));
144
145 let res = chain.next(req).await;
146 TestResponse { response: res }
147 }
148}
149
/// Response returned by a [`TestClient`] request, with inspection and assertion helpers.
pub struct TestResponse {
    /// The response produced by the middleware/handler chain.
    pub response: Response,
}
153
154impl TestResponse {
155 pub fn status(&self) -> u16 {
156 self.response.status().as_u16()
157 }
158
159 pub fn text(&self) -> String {
160 String::from_utf8(self.response.body().clone()).unwrap_or_default()
161 }
162
163 pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
164 serde_json::from_slice(self.response.body())
165 }
166
167 pub fn assert_status(&self, code: u16) {
168 assert_eq!(self.status(), code, "Response status code was {}, expected {}", self.status(), code);
169 }
170
171 pub fn assert_see(&self, val: &str) {
172 let txt = self.text();
173 assert!(txt.contains(val), "Response did not contain '{}'. Body: {}", val, txt);
174 }
175
176 pub fn assert_dont_see(&self, val: &str) {
177 let txt = self.text();
178 assert!(!txt.contains(val), "Response contained '{}' when it shouldn't. Body: {}", val, txt);
179 }
180}