rustbasic_core/
testing.rs1use crate::app::Config;
2use crate::server::AppState;
3use crate::router::{Router, Response};
4use crate::requests::Request;
5use crate::session::Session;
6use crate::session_manager::RustBasicSessionStore;
7use std::sync::Arc;
8use crate::rand::distr::SampleString;
9
10#[derive(Clone)]
11pub struct TestClient {
12 pub state: AppState,
13 pub router: Router<AppState>,
14 pub session_store: RustBasicSessionStore,
15}
16
17impl TestClient {
18 pub async fn new(cfg: Config, router: Router<AppState>) -> Self {
19 let mut routes_map = std::collections::HashMap::new();
21 for r in &router.routes {
22 if let Some(ref name) = r.name {
23 routes_map.insert(name.clone(), r.path.clone());
24 }
25 }
26 let _ = crate::router::NAMED_ROUTES.set(routes_map);
27
28 let db = crate::database::connect(&cfg).await;
30
31 crate::session::init_sessions(&cfg).await;
33 let session_store = crate::session::setup_session(&cfg).await;
34
35 Self {
36 state: AppState {
37 db,
38 config: Arc::new(cfg),
39 },
40 router,
41 session_store,
42 }
43 }
44
45 pub async fn get(&self, path: &str) -> TestResponse {
46 self.send_request("GET", path, None, None).await
47 }
48
49 pub async fn post(&self, path: &str, body: serde_json::Value) -> TestResponse {
50 self.send_request("POST", path, Some(body), None).await
51 }
52
53 pub async fn put(&self, path: &str, body: serde_json::Value) -> TestResponse {
54 self.send_request("PUT", path, Some(body), None).await
55 }
56
57 pub async fn patch(&self, path: &str, body: serde_json::Value) -> TestResponse {
58 self.send_request("PATCH", path, Some(body), None).await
59 }
60
61 pub async fn delete(&self, path: &str) -> TestResponse {
62 self.send_request("DELETE", path, None, None).await
63 }
64
65 pub async fn send_request(
66 &self,
67 method_str: &str,
68 path: &str,
69 body_json: Option<serde_json::Value>,
70 headers_opt: Option<std::collections::HashMap<String, String>>,
71 ) -> TestResponse {
72 let method = http::Method::from_bytes(method_str.as_bytes()).unwrap();
73 let inputs = body_json.unwrap_or_else(|| serde_json::json!({}));
74 let mut headers = headers_opt.unwrap_or_default();
75
76 let id = crate::rand::distr::Alphanumeric.sample_string(&mut crate::rand::rng(), 40);
77 let session = Session::new(id.clone());
78
79 let token = crate::rand::distr::Alphanumeric.sample_string(&mut crate::rand::rng(), 40);
81 session.set("_token", token.clone());
82
83 if !headers.contains_key("x-csrf-token") {
85 headers.insert("x-csrf-token".to_string(), token);
86 }
87
88 let req = Request {
89 inputs,
90 method: method.clone(),
91 path: path.to_string(),
92 headers,
93 session: session.clone(),
94 state: self.state.clone(),
95 ip_address: "127.0.0.1".to_string(),
96 params: std::collections::HashMap::new(),
97 };
98
99 struct RouteDispatcher {
100 router: Router<AppState>,
101 }
102
103 #[crate::async_trait]
104 impl crate::router::ErasedHandler for RouteDispatcher {
105 async fn call(&self, req: Request) -> Response {
106 let method = req.method.clone();
107 let path = req.path.clone();
108
109 let mut matched_handler = None;
110 let mut matched_params = std::collections::HashMap::new();
111 for route in &self.router.routes {
112 if crate::server::match_path(&route.path, &path) {
113 for (m, h) in &route.handlers {
114 if m == &method {
115 matched_handler = Some(h.clone());
116 matched_params = crate::server::extract_params(&route.path, &path);
117 break;
118 }
119 }
120 }
121 if matched_handler.is_some() {
122 break;
123 }
124 }
125
126 if let Some(handler) = matched_handler {
127 let mut req = req;
128 req.params = matched_params;
129 let mut chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::End(handler));
130 for mw in self.router.middlewares.iter().rev() {
131 chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::Next(mw.clone(), chain));
132 }
133 chain.next(req).await
134 } else {
135 crate::errors::ErrorController::not_found().await
136 }
137 }
138 }
139
140 let dispatcher = std::sync::Arc::new(RouteDispatcher {
141 router: self.router.clone(),
142 });
143
144 let mut chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::End(dispatcher));
145 chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::Next(
146 crate::middleware::from_fn(crate::middleware::security_headers::security_headers_middleware),
147 chain,
148 ));
149 chain = std::sync::Arc::new(crate::middleware::MiddlewareChain::Next(
150 crate::middleware::from_fn(crate::middleware::logging::logging_middleware),
151 chain,
152 ));
153
154 let res = chain.next(req).await;
155 TestResponse { response: res }
156 }
157}
158
159pub struct TestResponse {
160 pub response: Response,
161}
162
163impl TestResponse {
164 pub fn status(&self) -> u16 {
165 self.response.status().as_u16()
166 }
167
168 pub fn text(&self) -> String {
169 String::from_utf8(self.response.body().clone()).unwrap_or_default()
170 }
171
172 pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
173 serde_json::from_slice(self.response.body())
174 }
175
176 pub fn assert_status(&self, code: u16) {
177 assert_eq!(self.status(), code, "Response status code was {}, expected {}", self.status(), code);
178 }
179
180 pub fn assert_see(&self, val: &str) {
181 let txt = self.text();
182 assert!(txt.contains(val), "Response did not contain '{}'. Body: {}", val, txt);
183 }
184
185 pub fn assert_dont_see(&self, val: &str) {
186 let txt = self.text();
187 assert!(!txt.contains(val), "Response contained '{}' when it shouldn't. Body: {}", val, txt);
188 }
189}