1use super::config::RetryStrategy;
9use super::error::HttpError;
10use serde::Serialize;
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use std::time::Duration;
15
/// HTTP request method attached to a [`RequestBuilder`].
///
/// The type is `Copy`, comparable and hashable, so it can be passed by value
/// and used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    /// The `GET` method.
    GET,
    /// The `POST` method.
    POST,
    /// The `PUT` method.
    PUT,
    /// The `DELETE` method.
    DELETE,
    /// The `PATCH` method.
    PATCH,
    /// The `HEAD` method.
    HEAD,
    /// The `OPTIONS` method.
    OPTIONS,
}
27
28impl std::fmt::Display for Method {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 Method::GET => write!(f, "GET"),
32 Method::POST => write!(f, "POST"),
33 Method::PUT => write!(f, "PUT"),
34 Method::DELETE => write!(f, "DELETE"),
35 Method::PATCH => write!(f, "PATCH"),
36 Method::HEAD => write!(f, "HEAD"),
37 Method::OPTIONS => write!(f, "OPTIONS"),
38 }
39 }
40}
41
/// Payload carried by a [`RequestBuilder`], already encoded by the builder
/// method that produced it.
#[derive(Debug, Clone)]
pub(crate) enum RequestBody {
    /// No body; the state a freshly created builder starts in.
    None,
    /// JSON bytes produced by `RequestBuilder::json` via `serde_json::to_vec`.
    Json(Vec<u8>),
    /// URL-encoded form string produced by `RequestBuilder::form`.
    Form(String),
    /// Raw bytes supplied through `RequestBuilder::body`.
    Bytes(Vec<u8>),
    /// Plain text supplied through `RequestBuilder::text`.
    Text(String),
}
56
/// Accumulates everything needed to describe one HTTP request.
///
/// Created with `RequestBuilder::new` and refined through its consuming,
/// chainable methods (`header`, `query`, `json`, `timeout`, ...).
#[derive(Debug)]
pub struct RequestBuilder {
    /// HTTP method of the request.
    pub(crate) method: Method,
    /// Target URL exactly as given to `new`; query pairs are kept separately
    /// and only merged in by `build_url`.
    pub(crate) url: String,
    /// Header name -> value; inserting an existing name overwrites its value.
    pub(crate) headers: HashMap<String, String>,
    /// Query parameters in insertion order, stored un-encoded.
    pub(crate) query: Vec<(String, String)>,
    /// Request payload; `RequestBody::None` until a body method is called.
    pub(crate) body: RequestBody,
    /// Timeout set via `timeout`; `None` when the caller never set one.
    pub(crate) timeout: Option<Duration>,
    /// Retry count set via `max_retries`/`no_retry`; `None` when unset.
    pub(crate) max_retries: Option<u32>,
    /// Retry strategy override; only `no_retry` sets it in this file.
    pub(crate) retry_strategy: Option<RetryStrategy>,
}
83
84impl RequestBuilder {
85 pub fn new(method: Method, url: impl Into<String>) -> Self {
87 Self {
88 method,
89 url: url.into(),
90 headers: HashMap::new(),
91 query: Vec::new(),
92 body: RequestBody::None,
93 timeout: None,
94 max_retries: None,
95 retry_strategy: None,
96 }
97 }
98
99 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
101 self.headers.insert(name.into(), value.into());
102 self
103 }
104
105 pub fn headers(mut self, headers: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>) -> Self {
107 for (k, v) in headers {
108 self.headers.insert(k.into(), v.into());
109 }
110 self
111 }
112
113 pub fn bearer_auth(self, token: impl AsRef<str>) -> Self {
115 self.header("Authorization", format!("Bearer {}", token.as_ref()))
116 }
117
118 pub fn basic_auth(self, username: impl AsRef<str>, password: Option<&str>) -> Self {
120 use base64::Engine;
121 let credentials = match password {
122 Some(p) => format!("{}:{}", username.as_ref(), p),
123 None => username.as_ref().to_string(),
124 };
125 let encoded = base64::engine::general_purpose::STANDARD.encode(credentials);
126 self.header("Authorization", format!("Basic {}", encoded))
127 }
128
129 pub fn query(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
131 self.query.push((key.into(), value.into()));
132 self
133 }
134
135 pub fn queries(mut self, params: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>) -> Self {
137 for (k, v) in params {
138 self.query.push((k.into(), v.into()));
139 }
140 self
141 }
142
143 pub fn json<T: Serialize>(mut self, body: &T) -> Result<Self, HttpError> {
145 let bytes = serde_json::to_vec(body)
146 .map_err(|e| HttpError::JsonSerialize(e.to_string()))?;
147 self.body = RequestBody::Json(bytes);
148 self.headers.insert("Content-Type".into(), "application/json".into());
149 Ok(self)
150 }
151
152 pub fn form<T: Serialize>(mut self, body: &T) -> Result<Self, HttpError> {
154 let encoded = serde_urlencoded::to_string(body)
155 .map_err(|e| HttpError::FormEncode(e.to_string()))?;
156 self.body = RequestBody::Form(encoded);
157 self.headers
158 .insert("Content-Type".into(), "application/x-www-form-urlencoded".into());
159 Ok(self)
160 }
161
162 pub fn body(mut self, bytes: impl Into<Vec<u8>>) -> Self {
164 self.body = RequestBody::Bytes(bytes.into());
165 self
166 }
167
168 pub fn text(mut self, text: impl Into<String>) -> Self {
170 self.body = RequestBody::Text(text.into());
171 self.headers.insert("Content-Type".into(), "text/plain; charset=utf-8".into());
172 self
173 }
174
175 pub fn timeout(mut self, timeout: Duration) -> Self {
177 self.timeout = Some(timeout);
178 self
179 }
180
181 pub fn max_retries(mut self, retries: u32) -> Self {
183 self.max_retries = Some(retries);
184 self
185 }
186
187 pub fn no_retry(mut self) -> Self {
189 self.max_retries = Some(0);
190 self.retry_strategy = Some(RetryStrategy::None);
191 self
192 }
193
194 pub fn with_cancellation(self, token: CancellationToken) -> CancellableRequest {
196 CancellableRequest::new(self, token)
197 }
198
199 pub(crate) fn build_url(&self) -> Result<String, HttpError> {
201 if self.query.is_empty() {
202 return Ok(self.url.clone());
203 }
204
205 let query_string = self
206 .query
207 .iter()
208 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
209 .collect::<Vec<_>>()
210 .join("&");
211
212 let separator = if self.url.contains('?') { "&" } else { "?" };
213 Ok(format!("{}{}{}", self.url, separator, query_string))
214 }
215}
216
/// Read side of a cancellation flag created by `cancellation_pair`.
///
/// Cloning is cheap and every clone observes the same shared flag.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    /// Flag shared with the matching `CancellationTrigger`.
    cancelled: Arc<AtomicBool>,
}
228
229impl CancellationToken {
230 pub fn is_cancelled(&self) -> bool {
232 self.cancelled.load(Ordering::Acquire)
233 }
234
235 pub async fn cancelled(&self) {
237 while !self.is_cancelled() {
238 tokio::time::sleep(Duration::from_millis(10)).await;
239 }
240 }
241}
242
/// Write side of a cancellation flag created by `cancellation_pair`.
///
/// Deliberately not `Clone`; `Debug` is derived so the trigger can appear in
/// diagnostics and in `Debug`-deriving containers.
#[derive(Debug)]
pub struct CancellationTrigger {
    /// Flag shared with the matching `CancellationToken`(s).
    cancelled: Arc<AtomicBool>,
}
249
250impl CancellationTrigger {
251 pub fn cancel(&self) {
253 self.cancelled.store(true, Ordering::Release);
254 }
255}
256
257pub fn cancellation_pair() -> (CancellationToken, CancellationTrigger) {
261 let cancelled = Arc::new(AtomicBool::new(false));
262 (
263 CancellationToken { cancelled: cancelled.clone() },
264 CancellationTrigger { cancelled },
265 )
266}
267
/// A [`RequestBuilder`] bundled with the [`CancellationToken`] that was
/// attached to it (see `RequestBuilder::with_cancellation`).
pub struct CancellableRequest {
    /// The request description.
    pub(crate) builder: RequestBuilder,
    /// Token attached to the request.
    pub(crate) cancel_token: CancellationToken,
}
275
276impl CancellableRequest {
277 pub fn new(builder: RequestBuilder, token: CancellationToken) -> Self {
279 Self {
280 builder,
281 cancel_token: token,
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `header`/`query` calls land in the corresponding fields.
    #[test]
    fn test_request_builder_basic() {
        let req = RequestBuilder::new(Method::GET, "https://example.com")
            .header("Accept", "application/json")
            .query("page", "1")
            .query("limit", "10");

        assert_eq!(req.method, Method::GET);
        assert_eq!(
            req.headers.get("Accept").map(String::as_str),
            Some("application/json")
        );
        assert_eq!(req.query.len(), 2);
    }

    /// Query pairs are percent-encoded and appended in insertion order.
    #[test]
    fn test_build_url_with_query() {
        let req = RequestBuilder::new(Method::GET, "https://example.com/api")
            .query("name", "test")
            .query("value", "hello world");

        let url = req.build_url().unwrap();
        assert_eq!(
            url,
            "https://example.com/api?name=test&value=hello%20world"
        );
    }

    /// Without query parameters the URL is returned untouched.
    #[test]
    fn test_build_url_without_query() {
        let req = RequestBuilder::new(Method::GET, "https://example.com/api");
        assert_eq!(req.build_url().unwrap(), "https://example.com/api");
    }

    #[test]
    fn test_bearer_auth() {
        let req = RequestBuilder::new(Method::GET, "https://example.com").bearer_auth("my_token");

        assert_eq!(
            req.headers.get("Authorization").map(String::as_str),
            Some("Bearer my_token")
        );
    }

    #[test]
    fn test_basic_auth() {
        let req =
            RequestBuilder::new(Method::GET, "https://example.com").basic_auth("user", Some("pass"));

        // base64("user:pass") == "dXNlcjpwYXNz"
        assert_eq!(
            req.headers.get("Authorization").map(String::as_str),
            Some("Basic dXNlcjpwYXNz")
        );
    }

    #[test]
    fn test_json_body() {
        #[derive(Serialize)]
        struct Data {
            name: String,
        }

        let req = RequestBuilder::new(Method::POST, "https://example.com")
            .json(&Data {
                name: "test".into(),
            })
            .unwrap();

        assert_eq!(
            req.headers.get("Content-Type").map(String::as_str),
            Some("application/json")
        );
        // `matches!` only yields a bool; without `assert!` nothing is checked.
        assert!(matches!(req.body, RequestBody::Json(_)));
    }

    #[test]
    fn test_method_display() {
        assert_eq!(Method::GET.to_string(), "GET");
        assert_eq!(Method::POST.to_string(), "POST");
        assert_eq!(Method::PUT.to_string(), "PUT");
        assert_eq!(Method::DELETE.to_string(), "DELETE");
        assert_eq!(Method::PATCH.to_string(), "PATCH");
        assert_eq!(Method::HEAD.to_string(), "HEAD");
        assert_eq!(Method::OPTIONS.to_string(), "OPTIONS");
    }

    #[test]
    fn test_cancellation_token() {
        let (token, trigger) = cancellation_pair();
        assert!(!token.is_cancelled());
        trigger.cancel();
        assert!(token.is_cancelled());
        // Clones share the same flag.
        assert!(token.clone().is_cancelled());
    }
}
366}