tower_api_client/
client.rs

1use crate::error::{Error, Result};
2use crate::pagination::{PaginatedRequest, PaginationStream};
3use crate::request::{Request, RequestData};
4use base64::{engine::general_purpose::STANDARD, Engine};
5use futures::prelude::*;
6use hyper::header::{HeaderMap, HeaderName, HeaderValue};
7use hyper::{
8    body::{to_bytes, Body},
9    client::HttpConnector,
10    http::request::Builder,
11    Client as HyperClient,
12};
13use hyper_tls::HttpsConnector;
14use log::debug;
15use secrecy::Secret;
16use std::collections::HashMap;
17use std::convert::TryFrom;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use tower::Service;
21
22#[derive(Clone)]
23enum Authorization {
24    Bearer(Secret<String>),
25    Basic(String, Option<Secret<String>>),
26    Query(HashMap<String, Secret<String>>),
27    Header(HeaderMap<HeaderValue>),
28}
29
30/// The main client used for making requests.
31///
32/// `Client` stores an async Reqwest client as well as the associated
33/// base url and possible authorization details for the REST server.
34#[derive(Clone)]
35pub struct Client {
36    inner: HyperClient<HttpsConnector<HttpConnector>, Body>,
37    base_url: String,
38    default_headers: HeaderMap<HeaderValue>,
39    auth: Option<Authorization>,
40}
41
42impl<R: Request + 'static> Service<R> for Client {
43    type Response = R::Response;
44    type Error = Error;
45    type Future = Pin<Box<dyn Send + Future<Output = Result<Self::Response>>>>;
46
47    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<()>> {
48        Poll::Ready(Ok(()))
49    }
50
51    fn call(&mut self, request: R) -> Self::Future {
52        let this = self.clone();
53        Box::pin(async move { this.send(request).await })
54    }
55}
56
57impl Client {
58    /// Create a new `Client`.
59    pub fn new<S: ToString>(base_url: S) -> Self {
60        let connector = HttpsConnector::new();
61        let client = HyperClient::builder().build(connector);
62
63        Self::from_hyper(client, base_url)
64    }
65
66    /// Create a new `Client` from an existing Hyper Client.
67    pub fn from_hyper<S: ToString>(
68        inner: HyperClient<HttpsConnector<HttpConnector>>,
69        base_url: S,
70    ) -> Self {
71        Self {
72            inner,
73            base_url: base_url.to_string(),
74            default_headers: HeaderMap::default(),
75            auth: None,
76        }
77    }
78
79    /// Enable bearer authentication for the client
80    pub fn bearer_auth<S: ToString>(mut self, token: S) -> Self {
81        self.auth = Some(Authorization::Bearer(Secret::new(token.to_string())));
82        self
83    }
84
85    /// Enable basic authentication for the client
86    pub fn basic_auth<T: Into<Option<S>>, S: ToString>(mut self, user: S, pass: T) -> Self {
87        self.auth = Some(Authorization::Basic(
88            user.to_string(),
89            pass.into().map(|x| Secret::new(x.to_string())),
90        ));
91        self
92    }
93
94    /// Enable query authentication for the client
95    pub fn query_auth<S: ToString>(mut self, pairs: Vec<(S, S)>) -> Self {
96        let pairs = pairs
97            .into_iter()
98            .map(|(k, v)| (k.to_string(), Secret::new(v.to_string())))
99            .collect();
100        self.auth = Some(Authorization::Query(pairs));
101        self
102    }
103
104    /// Enable custom header authentication for the client
105    pub fn header_auth<S: ToString>(mut self, pairs: Vec<(S, S)>) -> Self {
106        let mut map = HeaderMap::new();
107        for (k, v) in pairs {
108            let k = k.to_string();
109            let v = v.to_string();
110            let mut header_value = HeaderValue::from_str(&v).expect("Failed to create HeaderValue");
111            header_value.set_sensitive(true);
112            map.insert(
113                HeaderName::try_from(&k).expect("Failed to create HeaderName"),
114                header_value,
115            );
116        }
117        self.auth = Some(Authorization::Header(map));
118        self
119    }
120
121    pub fn default_headers(mut self, default_headers: HeaderMap<HeaderValue>) -> Self {
122        self.default_headers = default_headers;
123        self
124    }
125
126    fn send_raw<R>(&self, req: hyper::Request<Body>) -> impl std::future::Future<Output = Result<R>>
127    where
128        R: for<'de> serde::Deserialize<'de>,
129    {
130        debug!("Sending request: {:?}", req);
131        self.inner
132            .request(req)
133            .map_err(From::from)
134            .and_then(|mut res| async move {
135                let status = res.status();
136                let body = to_bytes(res.body_mut()).await?;
137                if status.is_success() {
138                    serde_json::from_slice(&body).map_err(From::from)
139                } else if status.is_client_error() {
140                    Err(Error::ClientError(status, String::from_utf8(body.into())?))
141                } else {
142                    Err(Error::ServerError(status, String::from_utf8(body.into())?))
143                }
144            })
145    }
146
147    fn format_request<R: Request>(&self, request: &R) -> Result<hyper::Request<Body>> {
148        let endpoint = request.endpoint();
149        let endpoint = endpoint.trim_matches('/');
150        let url = format!("{}/{}", self.base_url, endpoint);
151
152        let mut headers = self.default_headers.clone();
153        headers.extend(request.headers());
154        let mut req = Builder::new().uri(&url).method(R::METHOD);
155        for header in headers {
156            req = req.header(header.0.expect("Always has a header name"), header.1);
157        }
158
159        req = {
160            use secrecy::ExposeSecret;
161            match &self.auth {
162                None => req,
163                Some(Authorization::Bearer(token)) => {
164                    let mut header_value =
165                        HeaderValue::from_str(&format!("Bearer {}", token.expose_secret()))
166                            .expect("Failed to create HeaderValue");
167                    header_value.set_sensitive(true);
168                    req.header("authorization", header_value)
169                }
170                Some(Authorization::Basic(user, pass)) => {
171                    let creds = format!(
172                        "{}:{}",
173                        user,
174                        pass.as_ref()
175                            .map(|x| x.expose_secret())
176                            .unwrap_or(&String::new())
177                    );
178                    let encoded = STANDARD.encode(creds);
179                    let mut header_value = HeaderValue::from_str(&format!("Basic {}", encoded))
180                        .expect("Failed to create HeaderValue");
181                    header_value.set_sensitive(true);
182                    req.header("authorization", header_value)
183                }
184                Some(Authorization::Query(pairs)) => {
185                    let pairs: HashMap<_, _> =
186                        pairs.iter().map(|(k, v)| (k, v.expose_secret())).collect();
187                    let query = serde_qs::to_string(&pairs)?;
188                    let url = if url.contains('?') {
189                        format!("{}&{}", url, query)
190                    } else {
191                        format!("{}?{}", url, query)
192                    };
193                    req.uri(url)
194                }
195                Some(Authorization::Header(pairs)) => {
196                    for (k, v) in pairs {
197                        req = req.header(k, v);
198                    }
199                    req
200                }
201            }
202        };
203
204        let body = match request.data() {
205            RequestData::Empty => Body::empty(),
206            RequestData::Form(data) => {
207                req = req
208                    .header("content-type", "application/x-www-form-urlencoded")
209                    .uri(url);
210                let body = serde_urlencoded::to_string(data)?;
211                Body::from(body)
212            }
213            RequestData::Json(data) => {
214                req = req.header("content-type", "application/json").uri(url);
215                let bytes = serde_json::to_vec(&data)?;
216                Body::from(bytes)
217            }
218            RequestData::Query(data) => {
219                let query = serde_qs::to_string(data)?;
220                let url = if url.contains('?') {
221                    format!("{}&{}", url, query)
222                } else {
223                    format!("{}?{}", url, query)
224                };
225                req = req.uri(url);
226                Body::empty()
227            }
228        };
229
230        req.body(body).map_err(From::from)
231    }
232
233    /// Send a single `Request`
234    pub async fn send<R: Request>(&self, request: R) -> Result<R::Response> {
235        let req = self.format_request(&request)?;
236        self.send_raw(req).await
237    }
238}
239
240pub trait ServiceExt<R, T>: Service<R> {
241    fn paginate(self, request: R) -> PaginationStream<Self, T, R>
242    where
243        T: Clone,
244        R: Request<Response = <Self as Service<R>>::Response>,
245        R: PaginatedRequest<PaginationData = T>,
246        Self: Sized,
247    {
248        PaginationStream::new(self, request)
249    }
250}
251
252impl<P, T, Request> ServiceExt<Request, P> for T
253where
254    T: ?Sized + Service<Request>,
255    Request: PaginatedRequest<PaginationData = P>,
256{
257}