1use super::{Body, Conn, Transport, TypeSet};
2use crate::{ClientHandler, ConnExt, Error, Result, Version};
3use smallvec::SmallVec;
4#[cfg(feature = "hickory")]
5use std::net::IpAddr;
6use std::{
7 borrow::Cow,
8 fmt::{self, Debug, Formatter},
9 future::{Future, IntoFuture},
10 mem,
11 net::SocketAddr,
12 pin::Pin,
13};
14use trillium_http::{ProtocolSession, Upgrade};
15use trillium_server_common::Destination;
16
#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
/// Error type combining the client's own [`Error`] with a JSON
/// (de)serialization error from whichever JSON backend feature is enabled.
///
/// NOTE(review): if both `serde_json` and `sonic-rs` are enabled at the same
/// time, two `JsonError` variants are declared and this type does not compile
/// — confirm the features are treated as mutually exclusive elsewhere.
#[derive(thiserror::Error, Debug)]
pub enum ClientSerdeError {
    /// An error from the HTTP client; `Display` and `source` are forwarded
    /// transparently to the inner error.
    #[error(transparent)]
    HttpError(#[from] Error),

    /// A JSON error from `sonic_rs`; forwarded transparently.
    #[cfg(feature = "sonic-rs")]
    #[error(transparent)]
    JsonError(#[from] sonic_rs::Error),

    /// A JSON error from `serde_json`; forwarded transparently.
    #[cfg(feature = "serde_json")]
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),
}
37
38impl Conn {
39 pub(crate) async fn exec(&mut self) -> Result<()> {
40 if let Some(error) = self.error.take() {
44 return Err(error);
45 }
46
47 let handler = self.client.arc_handler().clone();
49 handler.run(self).await?;
50
51 if !self.halted {
52 if let Err(e) = self.exec_network().await {
55 self.error = Some(e);
56 }
57 } else {
58 log::trace!("conn is halted, skipping network round-trip");
59 }
60
61 handler.after_response(self).await?;
63
64 if let Some(e) = self.error.take() {
65 Err(e)
66 } else {
67 Ok(())
68 }
69 }
70
71 async fn exec_network(&mut self) -> Result<()> {
72 if self.http_version == Some(Version::Http0_9) {
73 return Err(Error::UnsupportedVersion(Version::Http0_9));
74 }
75
76 if self.try_reuse_h3_pool().await? {
82 return Ok(());
83 }
84 if self.try_exec_h2_pooled().await? {
85 return Ok(());
86 }
87
88 if self.try_establish_h3().await? {
91 return Ok(());
92 }
93
94 if self.http_version == Some(Version::Http2) {
98 return self.exec_h2_prior_knowledge().await;
99 }
100
101 self.exec_h1_or_promote_h2().await
102 }
103
104 pub(crate) fn body_len(&self) -> Option<u64> {
105 if let Some(ref body) = self.request_body {
106 body.len()
107 } else {
108 Some(0)
109 }
110 }
111
112 pub(crate) fn finalize_headers(&mut self) -> Result<()> {
113 match self.http_version() {
114 Version::Http1_0 | Version::Http1_1 => self.finalize_headers_h1(),
115 Version::Http2 => self.finalize_headers_h2(),
116 Version::Http3 if self.client.h3().is_some() => self.finalize_headers_h3(),
117 other => Err(Error::UnsupportedVersion(other)),
118 }
119 }
120
121 pub(crate) async fn origin_destination(&self) -> Result<Destination> {
130 let mut destination = Destination::from_url(&self.url)?;
131 let addrs = self.origin_socket_addrs().await?;
132 if !addrs.is_empty() {
133 destination.set_addrs(addrs);
134 }
135 match self.http_version {
136 Some(Version::Http1_0 | Version::Http1_1) => {
137 destination.set_alpn([Cow::Borrowed(b"http/1.1".as_slice())]);
138 }
139 Some(Version::Http2) => {
140 destination.set_alpn([Cow::Borrowed(b"h2".as_slice())]);
141 }
142 _ => {}
143 }
144 Ok(destination)
145 }
146
147 pub(crate) async fn origin_socket_addrs(&self) -> Result<SmallVec<[SocketAddr; 4]>> {
151 let Some(host) = self.url.host_str() else {
152 return Ok(SmallVec::new());
153 };
154 let port = self.url.port_or_known_default().unwrap_or(443);
155 self.resolve_socket_addrs(host, port).await
156 }
157}
158
#[cfg(feature = "hickory")]
impl Conn {
    /// Resolves `host` through the client's configured resolver, if any.
    ///
    /// Returns `Ok(None)` without performing a lookup when `host` is an IP
    /// literal or when the client has no resolver. Bracketed IPv6 literals
    /// such as `[::1]` (the form `Url::host_str` produces) count as IP literals.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the resolver.
    pub(crate) async fn resolve(
        &self,
        host: &str,
        port: u16,
    ) -> Result<Option<crate::dns::Resolved>> {
        // `Url::host_str` wraps IPv6 literals in brackets, which the `IpAddr`
        // parser rejects. Strip them first so such hosts are never sent to DNS.
        let bare_host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        if bare_host.parse::<IpAddr>().is_ok() {
            return Ok(None);
        }

        let Some(resolver) = &self.client.resolver else {
            return Ok(None);
        };
        let resolved = resolver
            .resolve(&self.client, host, port, self.timeout)
            .await?;
        Ok(Some(resolved))
    }

    /// Resolves `host` to socket addresses on `port`.
    ///
    /// Returns an empty list when `resolve` performed no lookup (IP-literal
    /// host or no configured resolver).
    ///
    /// # Errors
    ///
    /// Propagates any error from `resolve`.
    pub(crate) async fn resolve_socket_addrs(
        &self,
        host: &str,
        port: u16,
    ) -> Result<SmallVec<[SocketAddr; 4]>> {
        Ok(self
            .resolve(host, port)
            .await?
            .map(|resolved| resolved.socket_addrs(port))
            .unwrap_or_default())
    }
}
203
204#[cfg(not(feature = "hickory"))]
205impl Conn {
206 pub(crate) async fn resolve_socket_addrs(
207 &self,
208 _host: &str,
209 _port: u16,
210 ) -> Result<SmallVec<[SocketAddr; 4]>> {
211 Ok(SmallVec::new())
212 }
213}
214
215impl Drop for Conn {
216 fn drop(&mut self) {
217 log::trace!("dropping client conn");
218 drop(self.take_response_body());
219 }
220}
221
222impl From<Conn> for Body {
223 fn from(mut conn: Conn) -> Body {
224 if let Some(body) = conn.body_override.take() {
227 return body;
228 }
229
230 match conn.take_received_body(true) {
231 Some(rb) => rb.into(),
232 None => Body::default(),
233 }
234 }
235}
236
impl From<Conn> for Upgrade<Box<dyn Transport>> {
    /// Converts this client conn into an [`Upgrade`], handing over its
    /// transport together with the exchange's headers, buffer, state and
    /// protocol metadata.
    ///
    /// # Panics
    ///
    /// Panics if the conn has no transport (the request has not been sent yet).
    fn from(mut conn: Conn) -> Self {
        // Use the explicitly set request path if there is one; otherwise
        // rebuild `path` or `path?query` from the URL.
        let path = conn.path.take().unwrap_or_else(|| match conn.url.query() {
            Some(q) => Cow::Owned(format!("{}?{q}", conn.url.path())),
            None => Cow::Owned(conn.url.path().to_owned()),
        });
        let secure = conn.url.scheme() == "https";

        // `Conn` implements `Drop`, so fields cannot be moved out directly;
        // owned fields are extracted with `take` / `mem::take` / `mem::replace`.
        // The arguments are positional and evaluated left to right.
        Upgrade::from_parts(
            mem::take(&mut conn.response_headers),
            mem::take(&mut conn.request_headers),
            path,
            conn.method,
            conn.transport
                .take()
                .expect("client conn has no transport — request not yet sent"),
            mem::take(&mut conn.buffer),
            mem::take(&mut conn.state),
            conn.context.clone(),
            None,
            conn.authority.take(),
            conn.scheme.take(),
            mem::replace(&mut conn.protocol_session, ProtocolSession::Http1),
            conn.protocol.take(),
            // NOTE(review): evaluated after `protocol_session` was replaced
            // with `Http1` just above; if `http_version()` consults the
            // session, confirm this ordering is intended.
            conn.http_version(),
            conn.status,
            secure,
            mem::take(&mut conn.response_body_state),
            conn.response_trailers.take(),
        )
    }
}
282
283impl IntoFuture for Conn {
284 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;
285 type Output = Result<Conn>;
286
287 fn into_future(mut self) -> Self::IntoFuture {
288 Box::pin(async move { (&mut self).await.map(|()| self) })
289 }
290}
291
292impl<'conn> IntoFuture for &'conn mut Conn {
293 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'conn>>;
294 type Output = Result<()>;
295
296 fn into_future(self) -> Self::IntoFuture {
297 Box::pin(async move {
298 loop {
301 let result = if let Some(duration) = self.timeout {
302 self.client
303 .connector()
304 .runtime()
305 .timeout(duration, self.exec())
306 .await
307 .unwrap_or(Err(Error::TimedOut("Conn", duration)))
308 } else {
309 self.exec().await
310 };
311
312 self.halted = false;
314
315 if let Err(e) = result {
316 self.followup = None;
319 return Err(e);
320 }
321
322 let Some(next) = self.take_followup() else {
323 break;
324 };
325
326 if let Some(body) = self.take_response_body() {
327 body.recycle().await;
328 }
329
330 let _displaced = mem::replace(self, next);
331 }
332 Ok(())
333 })
334 }
335}
336
337impl Debug for Conn {
338 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
339 f.debug_struct("Conn")
340 .field("authority", &self.authority)
341 .field("buffer", &String::from_utf8_lossy(&self.buffer))
342 .field("client", &self.client)
343 .field("protocol_session", &self.protocol_session)
344 .field("http_version", &self.http_version)
345 .field("method", &self.method)
346 .field("path", &self.path)
347 .field("request_body", &self.request_body)
348 .field("request_headers", &self.request_headers)
349 .field("request_target", &self.request_target)
350 .field("request_trailers", &self.request_trailers)
351 .field("response_body_state", &self.response_body_state)
352 .field("response_headers", &self.response_headers)
353 .field("response_trailers", &self.response_trailers)
354 .field("scheme", &self.scheme)
355 .field("state", &self.state)
356 .field("status", &self.status)
357 .field("url", &self.url)
358 .finish()
359 }
360}
361
362impl AsRef<TypeSet> for Conn {
363 fn as_ref(&self) -> &TypeSet {
364 &self.state
365 }
366}
367
368impl AsMut<TypeSet> for Conn {
369 fn as_mut(&mut self) -> &mut TypeSet {
370 &mut self.state
371 }
372}