1use crate::error::{Error, Result};
2use crate::pagination::{PaginatedRequest, PaginationStream};
3use crate::request::{Request, RequestData};
4use base64::{engine::general_purpose::STANDARD, Engine};
5use futures::prelude::*;
6use hyper::header::{HeaderMap, HeaderName, HeaderValue};
7use hyper::{
8 body::{to_bytes, Body},
9 client::HttpConnector,
10 http::request::Builder,
11 Client as HyperClient,
12};
13use hyper_tls::HttpsConnector;
14use log::debug;
15use secrecy::Secret;
16use std::collections::HashMap;
17use std::convert::TryFrom;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use tower::Service;
21
/// Authorization scheme that [`Client`] applies to each request it builds.
///
/// Secrets are held in [`Secret`]; header values built by `Client::header_auth`
/// are additionally marked sensitive.
#[derive(Clone)]
enum Authorization {
    /// Sent as an `authorization: Bearer <token>` header.
    Bearer(Secret<String>),
    /// HTTP Basic auth: `user:pass` base64-encoded into an `authorization: Basic …`
    /// header; a missing password is encoded as an empty string.
    Basic(String, Option<Secret<String>>),
    /// Key/value pairs serialized with `serde_qs` and appended to the request URL's
    /// query string.
    Query(HashMap<String, Secret<String>>),
    /// Headers added as-is to each request.
    Header(HeaderMap<HeaderValue>),
}
29
/// HTTP client that sends [`Request`]s to endpoints relative to a base URL.
///
/// Each request gets the configured default headers and, if set, one authorization
/// scheme. Successful (2xx) responses are decoded from JSON.
#[derive(Clone)]
pub struct Client {
    /// Underlying hyper client that performs the HTTP exchange.
    inner: HyperClient<HttpsConnector<HttpConnector>, Body>,
    /// Prefix joined with each request's endpoint to form the request URL.
    base_url: String,
    /// Headers used as the starting set for every request; the request's own headers
    /// are merged into a copy of this map.
    default_headers: HeaderMap<HeaderValue>,
    /// Authorization applied to every request, if any.
    auth: Option<Authorization>,
}
41
42impl<R: Request + 'static> Service<R> for Client {
43 type Response = R::Response;
44 type Error = Error;
45 type Future = Pin<Box<dyn Send + Future<Output = Result<Self::Response>>>>;
46
47 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<()>> {
48 Poll::Ready(Ok(()))
49 }
50
51 fn call(&mut self, request: R) -> Self::Future {
52 let this = self.clone();
53 Box::pin(async move { this.send(request).await })
54 }
55}
56
57impl Client {
58 pub fn new<S: ToString>(base_url: S) -> Self {
60 let connector = HttpsConnector::new();
61 let client = HyperClient::builder().build(connector);
62
63 Self::from_hyper(client, base_url)
64 }
65
66 pub fn from_hyper<S: ToString>(
68 inner: HyperClient<HttpsConnector<HttpConnector>>,
69 base_url: S,
70 ) -> Self {
71 Self {
72 inner,
73 base_url: base_url.to_string(),
74 default_headers: HeaderMap::default(),
75 auth: None,
76 }
77 }
78
79 pub fn bearer_auth<S: ToString>(mut self, token: S) -> Self {
81 self.auth = Some(Authorization::Bearer(Secret::new(token.to_string())));
82 self
83 }
84
85 pub fn basic_auth<T: Into<Option<S>>, S: ToString>(mut self, user: S, pass: T) -> Self {
87 self.auth = Some(Authorization::Basic(
88 user.to_string(),
89 pass.into().map(|x| Secret::new(x.to_string())),
90 ));
91 self
92 }
93
94 pub fn query_auth<S: ToString>(mut self, pairs: Vec<(S, S)>) -> Self {
96 let pairs = pairs
97 .into_iter()
98 .map(|(k, v)| (k.to_string(), Secret::new(v.to_string())))
99 .collect();
100 self.auth = Some(Authorization::Query(pairs));
101 self
102 }
103
104 pub fn header_auth<S: ToString>(mut self, pairs: Vec<(S, S)>) -> Self {
106 let mut map = HeaderMap::new();
107 for (k, v) in pairs {
108 let k = k.to_string();
109 let v = v.to_string();
110 let mut header_value = HeaderValue::from_str(&v).expect("Failed to create HeaderValue");
111 header_value.set_sensitive(true);
112 map.insert(
113 HeaderName::try_from(&k).expect("Failed to create HeaderName"),
114 header_value,
115 );
116 }
117 self.auth = Some(Authorization::Header(map));
118 self
119 }
120
121 pub fn default_headers(mut self, default_headers: HeaderMap<HeaderValue>) -> Self {
122 self.default_headers = default_headers;
123 self
124 }
125
126 fn send_raw<R>(&self, req: hyper::Request<Body>) -> impl std::future::Future<Output = Result<R>>
127 where
128 R: for<'de> serde::Deserialize<'de>,
129 {
130 debug!("Sending request: {:?}", req);
131 self.inner
132 .request(req)
133 .map_err(From::from)
134 .and_then(|mut res| async move {
135 let status = res.status();
136 let body = to_bytes(res.body_mut()).await?;
137 if status.is_success() {
138 serde_json::from_slice(&body).map_err(From::from)
139 } else if status.is_client_error() {
140 Err(Error::ClientError(status, String::from_utf8(body.into())?))
141 } else {
142 Err(Error::ServerError(status, String::from_utf8(body.into())?))
143 }
144 })
145 }
146
147 fn format_request<R: Request>(&self, request: &R) -> Result<hyper::Request<Body>> {
148 let endpoint = request.endpoint();
149 let endpoint = endpoint.trim_matches('/');
150 let url = format!("{}/{}", self.base_url, endpoint);
151
152 let mut headers = self.default_headers.clone();
153 headers.extend(request.headers());
154 let mut req = Builder::new().uri(&url).method(R::METHOD);
155 for header in headers {
156 req = req.header(header.0.expect("Always has a header name"), header.1);
157 }
158
159 req = {
160 use secrecy::ExposeSecret;
161 match &self.auth {
162 None => req,
163 Some(Authorization::Bearer(token)) => {
164 let mut header_value =
165 HeaderValue::from_str(&format!("Bearer {}", token.expose_secret()))
166 .expect("Failed to create HeaderValue");
167 header_value.set_sensitive(true);
168 req.header("authorization", header_value)
169 }
170 Some(Authorization::Basic(user, pass)) => {
171 let creds = format!(
172 "{}:{}",
173 user,
174 pass.as_ref()
175 .map(|x| x.expose_secret())
176 .unwrap_or(&String::new())
177 );
178 let encoded = STANDARD.encode(creds);
179 let mut header_value = HeaderValue::from_str(&format!("Basic {}", encoded))
180 .expect("Failed to create HeaderValue");
181 header_value.set_sensitive(true);
182 req.header("authorization", header_value)
183 }
184 Some(Authorization::Query(pairs)) => {
185 let pairs: HashMap<_, _> =
186 pairs.iter().map(|(k, v)| (k, v.expose_secret())).collect();
187 let query = serde_qs::to_string(&pairs)?;
188 let url = if url.contains('?') {
189 format!("{}&{}", url, query)
190 } else {
191 format!("{}?{}", url, query)
192 };
193 req.uri(url)
194 }
195 Some(Authorization::Header(pairs)) => {
196 for (k, v) in pairs {
197 req = req.header(k, v);
198 }
199 req
200 }
201 }
202 };
203
204 let body = match request.data() {
205 RequestData::Empty => Body::empty(),
206 RequestData::Form(data) => {
207 req = req
208 .header("content-type", "application/x-www-form-urlencoded")
209 .uri(url);
210 let body = serde_urlencoded::to_string(data)?;
211 Body::from(body)
212 }
213 RequestData::Json(data) => {
214 req = req.header("content-type", "application/json").uri(url);
215 let bytes = serde_json::to_vec(&data)?;
216 Body::from(bytes)
217 }
218 RequestData::Query(data) => {
219 let query = serde_qs::to_string(data)?;
220 let url = if url.contains('?') {
221 format!("{}&{}", url, query)
222 } else {
223 format!("{}?{}", url, query)
224 };
225 req = req.uri(url);
226 Body::empty()
227 }
228 };
229
230 req.body(body).map_err(From::from)
231 }
232
233 pub async fn send<R: Request>(&self, request: R) -> Result<R::Response> {
235 let req = self.format_request(&request)?;
236 self.send_raw(req).await
237 }
238}
239
240pub trait ServiceExt<R, T>: Service<R> {
241 fn paginate(self, request: R) -> PaginationStream<Self, T, R>
242 where
243 T: Clone,
244 R: Request<Response = <Self as Service<R>>::Response>,
245 R: PaginatedRequest<PaginationData = T>,
246 Self: Sized,
247 {
248 PaginationStream::new(self, request)
249 }
250}
251
252impl<P, T, Request> ServiceExt<Request, P> for T
253where
254 T: ?Sized + Service<Request>,
255 Request: PaginatedRequest<PaginationData = P>,
256{
257}