1use crate::{
2 ClientHandler, Conn, IntoUrl, Pool, USER_AGENT, client_handler::ArcedClientHandler,
3 conn::H2Pooled, h3::H3ClientState,
4};
5use std::{any::Any, fmt::Debug, sync::Arc, time::Duration};
6use trillium_http::{
7 HeaderName, HeaderValues, Headers, HttpContext, KnownHeaderName, Method, ProtocolSession,
8 ReceivedBodyState, TypeSet, Version::Http1_1,
9};
10use trillium_server_common::{
11 ArcedConnector, ArcedQuicClientConfig, Connector, QuicClientConfig, Transport,
12 url::{Origin, Url},
13};
14
/// Initial value of `Client::h2_idle_timeout`: five minutes.
const DEFAULT_H2_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Initial value of `Client::h2_idle_ping_threshold`: ten seconds.
///
/// NOTE(review): presumably the idle period after which a keepalive PING is sent on a
/// pooled HTTP/2 connection — confirm against the code that consumes the setting.
const DEFAULT_H2_IDLE_PING_THRESHOLD: Duration = Duration::from_secs(10);

/// Initial value of `Client::h2_idle_ping_timeout`: twenty seconds.
///
/// NOTE(review): presumably how long to wait for the PING acknowledgement — confirm
/// against the code that consumes the setting.
const DEFAULT_H2_IDLE_PING_TIMEOUT: Duration = Duration::from_secs(20);
27
#[derive(Clone, Debug, fieldwork::Fieldwork)]
/// A reusable HTTP client: a connector, optional connection pools, and per-request defaults.
///
/// Every `Conn` produced by `Client::build_conn` receives a clone of this `Client`, a copy
/// of its default headers and timeout, and a handle to its shared `HttpContext`.
pub struct Client {
    /// The type-erased connector supplied at construction; exposed through
    /// `Client::connector`.
    config: ArcedConnector,

    /// HTTP/3 state. `None` for clients built with `Client::new`, `Some` for clients built
    /// with `Client::new_with_quic`.
    #[field(vis = "pub(crate)", get)]
    h3: Option<H3ClientState>,

    /// Per-origin pool of boxed transports. `None` once `Client::without_keepalive` has
    /// been applied.
    // NOTE(review): presumably holds idle HTTP/1.x connections for reuse — confirm in `Pool`.
    #[field(vis = "pub(crate)", get)]
    pool: Option<Pool<Origin, Box<dyn Transport>>>,

    /// Per-origin pool of `H2Pooled` entries. `None` once `Client::without_keepalive` has
    /// been applied.
    #[field(vis = "pub(crate)", get)]
    h2_pool: Option<Pool<Origin, H2Pooled>>,

    /// HTTP/2 idle timeout; starts as `Some(DEFAULT_H2_IDLE_TIMEOUT)`.
    // NOTE(review): the meaning of `None` (presumably "never expire") is decided where this
    // value is consumed, which is not shown here.
    #[field(get, set, with, without, copy)]
    h2_idle_timeout: Option<Duration>,

    /// HTTP/2 idle ping threshold; starts as `Some(DEFAULT_H2_IDLE_PING_THRESHOLD)`.
    // NOTE(review): the meaning of `None` (presumably "no keepalive pings") is decided where
    // this value is consumed, which is not shown here.
    #[field(get, set, with, copy, without)]
    h2_idle_ping_threshold: Option<Duration>,

    /// HTTP/2 idle ping timeout; starts as `DEFAULT_H2_IDLE_PING_TIMEOUT`.
    #[field(get, set, with, copy)]
    h2_idle_ping_timeout: Duration,

    /// Optional base url. `Client::set_base` guarantees the stored path ends in `/`.
    #[field(get)]
    base: Option<Arc<Url>>,

    /// Headers copied into the request headers of every conn built by this client.
    #[field(get)]
    default_headers: Arc<Headers>,

    /// Optional timeout copied onto every conn built by this client; `None` by default.
    #[field(get, set, with, copy, without, option_set_some)]
    timeout: Option<Duration>,

    /// The `HttpContext` shared (via `Arc`) with every conn built by this client.
    #[field(get, get_mut, set, with, into)]
    context: Arc<HttpContext>,

    /// The type-erased client handler; constructors install the unit handler `()`.
    #[field(vis = "pub(crate)", get)]
    handler: ArcedClientHandler,
}
87
// Generates a `pub fn $fn_name(&self, url: impl IntoUrl) -> Conn` that delegates to
// `build_conn` with `Method::$method`.
//
// Two arms:
// * `method!(name, Variant)` assembles a default doc comment (including a doctest that
//   exercises the generated method) and forwards to the three-argument arm.
// * `method!(name, Variant, doc)` emits the method itself with `doc` as its doc comment.
macro_rules! method {
    ($fn_name:ident, $method:ident) => {
        method!(
            $fn_name,
            $method,
            // The doc string is stitched together at compile time so that both the prose
            // and the doctest mention the concrete method name and `Method` variant.
            concat!(
                "Builds a new client conn with the ",
                stringify!($fn_name),
                " http method and the provided url.

```
use trillium_client::{Client, Method};
use trillium_testing::client_config;

let client = Client::new(client_config());
let conn = client.",
                stringify!($fn_name),
                "(\"http://localhost:8080/some/route\"); //<-

assert_eq!(conn.method(), Method::",
                stringify!($method),
                ");
assert_eq!(conn.url().to_string(), \"http://localhost:8080/some/route\");
```
"
            )
        );
    };

    // Emits the actual method; `$doc_comment` becomes its rustdoc text.
    ($fn_name:ident, $method:ident, $doc_comment:expr_2021) => {
        #[doc = $doc_comment]
        pub fn $fn_name(&self, url: impl IntoUrl) -> Conn {
            self.build_conn(Method::$method, url)
        }
    };
}
125
126pub(crate) fn default_request_headers() -> Headers {
127 Headers::new()
128 .with_inserted_header(KnownHeaderName::UserAgent, USER_AGENT)
129 .with_inserted_header(KnownHeaderName::Accept, "*/*")
130}
131
132impl Client {
    // Per-verb conn builders generated by the `method!` macro: each line expands to
    // `pub fn <name>(&self, url: impl IntoUrl) -> Conn`, which calls `build_conn` with the
    // named `Method` variant.
    method!(get, Get);

    method!(post, Post);

    method!(put, Put);

    method!(delete, Delete);

    method!(patch, Patch);
142
143 pub fn new(connector: impl Connector) -> Self {
145 Self {
146 config: ArcedConnector::new(connector),
147 h3: None,
148 pool: Some(Pool::default()),
149 h2_pool: Some(Pool::default()),
150 h2_idle_timeout: Some(DEFAULT_H2_IDLE_TIMEOUT),
151 h2_idle_ping_threshold: Some(DEFAULT_H2_IDLE_PING_THRESHOLD),
152 h2_idle_ping_timeout: DEFAULT_H2_IDLE_PING_TIMEOUT,
153 base: None,
154 default_headers: Arc::new(default_request_headers()),
155 timeout: None,
156 context: Default::default(),
157 handler: ArcedClientHandler::new(()),
158 }
159 }
160
161 pub fn new_with_quic<C: Connector, Q: QuicClientConfig<C>>(connector: C, quic: Q) -> Self {
171 let arced_quic = ArcedQuicClientConfig::new(&connector, quic);
173
174 #[cfg_attr(not(feature = "webtransport"), allow(unused_mut))]
175 let mut context = HttpContext::default();
176 #[cfg(feature = "webtransport")]
177 {
178 context
183 .config_mut()
184 .set_h3_datagrams_enabled(true)
185 .set_webtransport_enabled(true)
186 .set_extended_connect_enabled(true);
187 }
188
189 Self {
190 config: ArcedConnector::new(connector),
191 h3: Some(H3ClientState::new(arced_quic)),
192 pool: Some(Pool::default()),
193 h2_pool: Some(Pool::default()),
194 h2_idle_timeout: Some(DEFAULT_H2_IDLE_TIMEOUT),
195 h2_idle_ping_threshold: Some(DEFAULT_H2_IDLE_PING_THRESHOLD),
196 h2_idle_ping_timeout: DEFAULT_H2_IDLE_PING_TIMEOUT,
197 base: None,
198 default_headers: Arc::new(default_request_headers()),
199 timeout: None,
200 context: Arc::new(context),
201 handler: ArcedClientHandler::new(()),
202 }
203 }
204
205 #[must_use]
214 pub fn with_handler<H: ClientHandler>(mut self, handler: H) -> Self {
215 self.set_handler(handler);
216 self
217 }
218
219 pub fn set_handler<H: ClientHandler>(&mut self, handler: H) -> &mut Self {
222 self.handler = ArcedClientHandler::new(handler);
223 self
224 }
225
226 pub fn downcast_handler<T: Any + 'static>(&self) -> Option<&T> {
232 self.handler.downcast_ref()
233 }
234
235 pub fn without_default_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
237 self.default_headers_mut().remove(name);
238 self
239 }
240
241 pub fn with_default_header(
243 mut self,
244 name: impl Into<HeaderName<'static>>,
245 value: impl Into<HeaderValues>,
246 ) -> Self {
247 self.default_headers_mut().insert(name, value);
248 self
249 }
250
251 pub fn default_headers_mut(&mut self) -> &mut Headers {
255 Arc::make_mut(&mut self.default_headers)
256 }
257
258 pub fn without_keepalive(mut self) -> Self {
267 self.pool = None;
268 self.h2_pool = None;
269 self
270 }
271
272 pub fn build_conn<M>(&self, method: M, url: impl IntoUrl) -> Conn
289 where
290 M: TryInto<Method>,
291 <M as TryInto<Method>>::Error: Debug,
292 {
293 let method = method.try_into().unwrap();
294 let (url, request_target) = if let Some(base) = &self.base
295 && let Some(request_target) = url.request_target(method)
296 {
297 ((**base).clone(), Some(request_target))
298 } else {
299 (self.build_url(url).unwrap(), None)
300 };
301
302 Conn {
303 url,
304 method,
305 request_headers: Headers::clone(&self.default_headers),
306 response_headers: Headers::new(),
307 transport: None,
308 status: None,
309 request_body: None,
310 protocol_session: ProtocolSession::Http1,
311 #[cfg(feature = "webtransport")]
312 wt_pool_entry: None,
313 buffer: Vec::with_capacity(128).into(),
314 response_body_state: ReceivedBodyState::Start,
315 headers_finalized: false,
316 halted: false,
317 error: None,
318 body_override: None,
319 timeout: self.timeout,
320 http_version: Http1_1,
321 max_head_length: 8 * 1024,
322 state: TypeSet::new(),
323 context: self.context.clone(),
324 authority: None,
325 scheme: None,
326 path: None,
327 request_target,
328 protocol: None,
329 request_trailers: None,
330 response_trailers: None,
331 client: self.clone(),
332 followup: None,
333 }
334 }
335
336 pub fn connector(&self) -> &ArcedConnector {
338 &self.config
339 }
340
341 pub fn clean_up_pool(&self) {
346 if let Some(pool) = &self.pool {
347 pool.cleanup();
348 }
349 if let Some(h2_pool) = &self.h2_pool {
350 h2_pool.cleanup();
351 }
352 }
353
354 pub fn with_base(mut self, base: impl IntoUrl) -> Self {
356 self.set_base(base).unwrap();
357 self
358 }
359
360 pub fn build_url(&self, url: impl IntoUrl) -> crate::Result<Url> {
362 url.into_url(self.base())
363 }
364
365 pub fn set_base(&mut self, base: impl IntoUrl) -> crate::Result<()> {
367 let mut base = base.into_url(None)?;
368
369 if !base.path().ends_with('/') {
370 log::warn!("appending a trailing / to {base}");
371 base.set_path(&format!("{}/", base.path()));
372 }
373
374 self.base = Some(Arc::new(base));
375 Ok(())
376 }
377
378 pub fn base_mut(&mut self) -> Option<&mut Url> {
383 let base = self.base.as_mut()?;
384 Some(Arc::make_mut(base))
385 }
386}
387
388impl<T: Connector> From<T> for Client {
389 fn from(connector: T) -> Self {
390 Self::new(connector)
391 }
392}