use crate::{Conn, IntoUrl, Pool, USER_AGENT};
use std::{convert::TryInto, fmt::Debug, sync::Arc};
use trillium_http::{
transport::BoxedTransport, HeaderName, HeaderValues, Headers, KnownHeaderName, Method,
ReceivedBodyState,
};
use trillium_server_common::{Connector, ObjectSafeConnector, Url};
use url::Origin;
/// An http client that builds [`Conn`]s.
///
/// Holds the connector used to open transports, an optional connection pool, an optional
/// base url used when building request urls, and a set of default request headers that is
/// copied into every conn this client builds.
///
/// The connector, base url, and default headers are stored behind `Arc`s, so cloning a
/// `Client` only bumps reference counts for those fields.
#[derive(Clone, Debug)]
pub struct Client {
    /// Connector used to establish transports; every conn built by this client gets a
    /// clone of this `Arc`.
    config: Arc<dyn ObjectSafeConnector>,
    /// Optional pool of `BoxedTransport`s, parameterized by [`Origin`]; `None` until a pool
    /// is configured.
    pool: Option<Pool<Origin, BoxedTransport>>,
    /// Optional base url, passed to `IntoUrl::into_url` when building request urls.
    base: Option<Arc<Url>>,
    /// Headers copied into the request headers of each newly built conn.
    default_headers: Arc<Headers>,
}
// Generates a public `Client` method named `$fn_name` that builds a conn using
// `Method::$method` by delegating to `Client::build_conn`.
//
// Two arms:
// - `method!(name, Variant)` assembles a doc comment (containing a doctest that calls the
//   generated method and checks the resulting method and url), then forwards to the
//   three-argument arm.
// - `method!(name, Variant, doc)` emits the method itself with `doc` attached.
macro_rules! method {
    // Two-argument arm: build the doc string with `concat!`/`stringify!` and re-invoke the
    // macro with it. The string pieces below are the literal doc text (including the
    // doctest), so their line breaks and lack of indentation are significant.
    ($fn_name:ident, $method:ident) => {
        method!(
            $fn_name,
            $method,
            concat!(
                "Builds a new client conn with the ",
                stringify!($fn_name),
                " http method and the provided url.
```
# use trillium_testing::prelude::*;
# use trillium_smol::ClientConfig;
# use trillium_client::Client;
let client = Client::new(ClientConfig::default());
let conn = client.",
                stringify!($fn_name),
                "(\"http://localhost:8080/some/route\"); //<-
assert_eq!(conn.method(), Method::",
                stringify!($method),
                ");
assert_eq!(conn.url().to_string(), \"http://localhost:8080/some/route\");
```
"
            )
        );
    };
    // Three-argument arm: emit the method with the supplied doc comment attached.
    ($fn_name:ident, $method:ident, $doc_comment:expr) => {
        #[doc = $doc_comment]
        pub fn $fn_name(&self, url: impl IntoUrl) -> Conn {
            self.build_conn(Method::$method, url)
        }
    };
}
/// Returns the header set every new `Client` starts with: a `User-Agent` of `USER_AGENT`
/// and `Accept: */*`.
pub(crate) fn default_request_headers() -> Headers {
    let mut defaults = Headers::new();
    defaults.insert(KnownHeaderName::UserAgent, USER_AGENT);
    defaults.insert(KnownHeaderName::Accept, "*/*");
    defaults
}
impl Client {
    /// Builds a new client from the provided connector.
    ///
    /// The client starts with no connection pool, no base url, and the default request
    /// headers (a `User-Agent` of `USER_AGENT` and `Accept: */*`).
    pub fn new(config: impl Connector) -> Self {
        Self {
            config: config.arced(),
            pool: None,
            base: None,
            default_headers: Arc::new(default_request_headers()),
        }
    }

    /// Chainable builder that removes `name` from the default request headers.
    ///
    /// Only this client value is affected: if the header set is shared with other clones,
    /// it is copied before being modified (see [`Client::default_headers_mut`]).
    #[must_use]
    pub fn without_default_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
        self.default_headers_mut().remove(name);
        self
    }

    /// Chainable builder that inserts `name: value` into the default request headers.
    ///
    /// Only this client value is affected: if the header set is shared with other clones,
    /// it is copied before being modified (see [`Client::default_headers_mut`]).
    #[must_use]
    pub fn with_default_header(
        mut self,
        name: impl Into<HeaderName<'static>>,
        value: impl Into<HeaderValues>,
    ) -> Self {
        self.default_headers_mut().insert(name, value);
        self
    }

    /// Borrows the headers that are copied into every conn this client builds.
    pub fn default_headers(&self) -> &Headers {
        &self.default_headers
    }

    /// Mutably borrows the default request headers.
    ///
    /// Uses [`Arc::make_mut`]: if other clones of this client share the same header set, it
    /// is cloned first, so changes made here do not affect those clones.
    pub fn default_headers_mut(&mut self) -> &mut Headers {
        Arc::make_mut(&mut self.default_headers)
    }

    /// Chainable builder that installs a default connection pool, replacing any pool that
    /// was previously configured.
    ///
    /// Conns built after this call each receive a clone of the pool; conns built earlier
    /// are unaffected.
    #[must_use]
    pub fn with_default_pool(mut self) -> Self {
        self.pool = Some(Pool::default());
        self
    }

    /// Builds a new [`Conn`] with the given method and url.
    ///
    /// The url is produced by [`Client::build_url`], which passes this client's base url
    /// (if any) along. The conn's request headers start as a copy of the default headers,
    /// it receives a clone of the pool (if one is configured), and it shares this client's
    /// connector.
    ///
    /// # Panics
    ///
    /// Panics if the url cannot be built, or if `method` cannot be converted into a
    /// [`Method`].
    pub fn build_conn<M>(&self, method: M, url: impl IntoUrl) -> Conn
    where
        M: TryInto<Method>,
        <M as TryInto<Method>>::Error: Debug,
    {
        Conn {
            url: self
                .build_url(url)
                .expect("Client::build_conn: url could not be built"),
            method: method
                .try_into()
                .expect("Client::build_conn: could not convert method into a Method"),
            request_headers: Headers::clone(&self.default_headers),
            response_headers: Headers::new(),
            transport: None,
            status: None,
            request_body: None,
            pool: self.pool.clone(),
            buffer: Vec::with_capacity(128).into(),
            response_body_state: ReceivedBodyState::Start,
            config: Arc::clone(&self.config),
            headers_finalized: false,
        }
    }

    /// Returns the connector this client uses; every conn built by this client shares it.
    pub fn connector(&self) -> &Arc<dyn ObjectSafeConnector> {
        &self.config
    }

    /// Calls `cleanup` on the connection pool if one is configured; does nothing otherwise.
    pub fn clean_up_pool(&self) {
        if let Some(pool) = &self.pool {
            pool.cleanup();
        }
    }

    /// Chainable builder that sets the base url (see [`Client::set_base`]).
    ///
    /// # Panics
    ///
    /// Panics if `base` cannot be converted into a url. Use [`Client::set_base`] to handle
    /// that error instead.
    #[must_use]
    pub fn with_base(mut self, base: impl IntoUrl) -> Self {
        self
            .set_base(base)
            .expect("Client::with_base: base could not be converted into a url");
        self
    }

    /// Returns the base url, if one has been set.
    pub fn base(&self) -> Option<&Url> {
        self.base.as_deref()
    }

    /// Builds a url from `url` by calling [`IntoUrl::into_url`] with this client's base url
    /// (if any).
    ///
    /// # Errors
    ///
    /// Returns whatever error `into_url` produces when the conversion fails.
    pub fn build_url(&self, url: impl IntoUrl) -> crate::Result<Url> {
        url.into_url(self.base())
    }

    /// Sets or replaces the base url used by [`Client::build_url`].
    ///
    /// If the base's path does not end in `/`, a trailing `/` is appended and a warning is
    /// logged.
    ///
    /// # Errors
    ///
    /// Returns an error if `base` cannot be converted into a url; the previous base is left
    /// unchanged in that case.
    pub fn set_base(&mut self, base: impl IntoUrl) -> crate::Result<()> {
        let mut base = base.into_url(None)?;
        // Presumably so that relative request urls nest under the base path instead of
        // replacing its final segment (url-crate join semantics) — confirm against the
        // `IntoUrl` implementation.
        if !base.path().ends_with('/') {
            log::warn!("appending a trailing / to {base}");
            base.set_path(&format!("{}/", base.path()));
        }
        self.base = Some(Arc::new(base));
        Ok(())
    }

    // Shorthand constructors (`get`, `post`, ...), each delegating to `build_conn` with the
    // corresponding `Method`; generated by the `method!` macro.
    method!(get, Get);
    method!(post, Post);
    method!(put, Put);
    method!(delete, Delete);
    method!(patch, Patch);
}
impl<T: Connector> From<T> for Client {
fn from(connector: T) -> Self {
Self::new(connector)
}
}