1use core::marker::PhantomData;
59use core::future::Future;
60use core::fmt;
61use std::path::Path;
62
63use crate::header;
64
65pub mod config;
66pub mod request;
67pub mod response;
68
69pub use request::Request;
70pub use response::Response;
71
72pub struct Client<C=config::DefaultCfg> where C: config::Config + 'static {
74 inner: hyper::Client<C::Connector>,
75 _config: PhantomData<C>
76}
77
78impl Default for Client {
79 fn default() -> Self {
81 Client::<config::DefaultCfg>::new()
82 }
83}
84
85impl<C: config::Config> fmt::Debug for Client<C> {
86 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 write!(f, "Yukikaze {{ HyperClient={:?} }}", self.inner)
88 }
89}
90
91pub type RequestResult = Result<response::Response, hyper::Error>;
93
94use tokio::io::{AsyncRead, AsyncWrite};
95
96impl<C: config::Config> Client<C> where <C::Connector as hyper::service::Service<hyper::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
97 <C::Connector as hyper::service::Service<hyper::Uri>>::Future: Send + Unpin,
98 <C::Connector as hyper::service::Service<hyper::Uri>>::Response: AsyncRead + AsyncWrite + hyper::client::connect::Connection + Unpin + Send
99{
100 pub fn new() -> Client<C> {
104 let inner = C::config_hyper(&mut hyper::Client::builder()).build(C::Connector::default());
105
106 Self {
107 inner,
108 _config: PhantomData
109 }
110 }
111
112 fn apply_headers(request: &mut request::Request) {
113 C::default_headers(request);
114
115 #[cfg(feature = "compu")]
116 {
117 const DEFAULT_COMPRESS: &'static str = "br, gzip, deflate";
118
119 if C::decompress() {
120 let headers = request.headers_mut();
121 if !headers.contains_key(header::ACCEPT_ENCODING) && headers.contains_key(header::RANGE) {
122 headers.insert(header::ACCEPT_ENCODING, header::HeaderValue::from_static(DEFAULT_COMPRESS));
123 }
124 }
125 }
126 }
127
128 pub async fn request(&self, mut req: request::Request) -> RequestResult {
130 Self::apply_headers(&mut req);
131
132 #[cfg(feature = "carry_extensions")]
133 let mut extensions = req.extract_extensions();
134
135 let ongoing = self.inner.request(req.into());
136 let ongoing = matsu!(ongoing).map(|res| response::Response::new(res));
137
138 #[cfg(feature = "carry_extensions")]
139 {
140 ongoing.map(move |resp| resp.replace_extensions(&mut extensions))
141 }
142 #[cfg(not(feature = "carry_extensions"))]
143 {
144 ongoing
145 }
146 }
147
148 pub async fn send(&self, mut req: request::Request) -> Result<RequestResult, async_timer::Expired<impl Future<Output=RequestResult>, C::Timer>> {
156 Self::apply_headers(&mut req);
157
158 #[cfg(feature = "carry_extensions")]
159 let mut extensions = req.extract_extensions();
160
161 let ongoing = self.inner.request(req.into());
162 let ongoing = async {
163 let res = matsu!(ongoing);
164 res.map(|resp| response::Response::new(resp))
165 };
166
167 let timeout = C::timeout();
168 match timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
169 #[cfg(not(feature = "carry_extensions"))]
170 true => Ok(matsu!(ongoing)),
171 #[cfg(feature = "carry_extensions")]
172 true => Ok(matsu!(ongoing).map(move |resp| resp.replace_extensions(&mut extensions))),
173 false => {
174 let job = unsafe { async_timer::Timed::<_, C::Timer>::new_unchecked(ongoing, timeout) };
175 #[cfg(not(feature = "carry_extensions"))]
176 {
177 matsu!(job)
178 }
179 #[cfg(feature = "carry_extensions")]
180 {
181 matsu!(job).map(move |res| res.map(move |resp| resp.replace_extensions(&mut extensions)))
182 }
183 }
184 }
185 }
186
187 pub async fn send_redirect(&'static self, req: request::Request) -> Result<RequestResult, async_timer::Expired<impl Future<Output=RequestResult> + 'static, C::Timer>> {
195 let timeout = C::timeout();
196 match timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
197 true => Ok(matsu!(self.redirect_request(req))),
198 false => {
199 let ongoing = self.redirect_request(req);
205 let job = unsafe { async_timer::Timed::<_, C::Timer>::new_unchecked(ongoing, timeout) };
206 matsu!(job)
207 }
208 }
209 }
210
211 pub async fn redirect_request(&self, mut req: request::Request) -> RequestResult {
213 use http::{Method, StatusCode};
214
215 Self::apply_headers(&mut req);
216
217 let mut rem_redirect = C::max_redirect_num();
218
219 let mut method = req.parts.method.clone();
220 let uri = req.parts.uri.clone();
221 let mut headers = req.parts.headers.clone();
222 let mut body = req.body.clone();
223 #[cfg(feature = "carry_extensions")]
224 let mut extensions = req.extract_extensions();
225
226 loop {
227 let ongoing = self.inner.request(req.into());
228 let res = matsu!(ongoing).map(|resp| response::Response::new(resp))?;
229
230 match res.status() {
231 StatusCode::SEE_OTHER => {
232 rem_redirect -= 1;
233 match rem_redirect {
234 #[cfg(feature = "carry_extensions")]
235 0 => return Ok(res.replace_extensions(&mut extensions)),
236 #[cfg(not(feature = "carry_extensions"))]
237 0 => return Ok(res),
238 _ => {
239 body = None;
242 method = Method::GET;
243 }
244 }
245 },
246 StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {
247 rem_redirect -= 1;
248 match rem_redirect {
249 #[cfg(feature = "carry_extensions")]
250 0 => return Ok(res.replace_extensions(&mut extensions)),
251 #[cfg(not(feature = "carry_extensions"))]
252 0 => return Ok(res),
253 _ => (),
254 }
255 }
256 #[cfg(feature = "carry_extensions")]
257 _ => return Ok(res.replace_extensions(&mut extensions)),
258 #[cfg(not(feature = "carry_extensions"))]
259 _ => return Ok(res),
260 }
261
262 let location = match res.headers().get(header::LOCATION).and_then(|loc| loc.to_str().ok()).and_then(|loc| loc.parse::<hyper::Uri>().ok()) {
263 Some(loc) => match loc.scheme().is_some() {
264 true => {
266 if let Some(prev_host) = uri.authority().map(|part| part.host()) {
268 match loc.authority().map(|part| part.host() == prev_host).unwrap_or(false) {
269 true => (),
270 false => {
271 headers.remove("authorization");
272 headers.remove("cookie");
273 headers.remove("cookie2");
274 headers.remove("www-authenticate");
275 }
276 }
277 }
278
279 loc
280 },
281 false => {
283 let current = Path::new(uri.path());
284 let loc = Path::new(loc.path());
285 let loc = current.join(loc);
286 let loc = loc.to_str().expect("Valid UTF-8 path").parse::<hyper::Uri>().expect("Valid URI");
287 let mut loc_parts = loc.into_parts();
288
289 loc_parts.scheme = uri.scheme().cloned();
290 loc_parts.authority = uri.authority().cloned();
291
292 hyper::Uri::from_parts(loc_parts).expect("Create redirect URI")
293 },
294 },
295 #[cfg(feature = "carry_extensions")]
296 None => return Ok(res.replace_extensions(&mut extensions)),
297 #[cfg(not(feature = "carry_extensions"))]
298 None => return Ok(res),
299 };
300
301 let (mut parts, _) = hyper::Request::<()>::new(()).into_parts();
302 parts.method = method.clone();
303 parts.uri = location;
304 parts.headers = headers.clone();
305
306 req = request::Request {
307 parts,
308 body: body.clone()
309 };
310 }
311 }
312}