Skip to main content

trillium_client/conn/
shared.rs

1use super::{Body, Conn, ReceivedBody, ReceivedBodyState, Transport, TypeSet, encoding};
2use crate::{Error, Result, Version, pool::PoolEntry};
3use futures_lite::{AsyncWriteExt, io};
4use std::{
5    fmt::{self, Debug, Formatter},
6    future::{Future, IntoFuture},
7    mem,
8    pin::Pin,
9};
10use trillium_http::Upgrade;
11
/// Error type for the client's JSON helpers.
///
/// It wraps either an HTTP-level [`trillium_http::Error`] or the error type of the JSON library
/// selected by cargo feature: `sonic_rs::Error` under `sonic-rs`, `serde_json::Error` under
/// `serde_json`. The type only exists when at least one of those two features is enabled.
///
/// NOTE(review): both feature-gated variants are named `JsonError`, so enabling `sonic-rs` and
/// `serde_json` at the same time produces a duplicate-variant compile error. Confirm the crate
/// treats these features as mutually exclusive.
#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
#[derive(thiserror::Error, Debug)]
pub enum ClientSerdeError {
    /// JSON (de)serialization failed; wraps a [`sonic_rs::Error`].
    #[cfg(feature = "sonic-rs")]
    #[error(transparent)]
    JsonError(#[from] sonic_rs::Error),

    /// JSON (de)serialization failed; wraps a [`serde_json::Error`].
    #[cfg(feature = "serde_json")]
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),

    /// The HTTP exchange itself failed; wraps a [`trillium_http::Error`].
    #[error(transparent)]
    HttpError(#[from] Error),
}
32
33impl Conn {
34    pub(crate) async fn exec(&mut self) -> Result<()> {
35        if matches!(self.http_version, Version::Http0_9) {
36            return Err(Error::UnsupportedVersion(self.http_version));
37        }
38
39        if self.try_exec_h3().await? {
40            return Ok(());
41        }
42        if self.try_exec_h2_pooled().await? {
43            return Ok(());
44        }
45
46        // h2 prior knowledge: `http_version = Http2` is an assertion that the server speaks
47        // h2, so we skip h1 entirely. Over `http://` this is h2c (cleartext immediate
48        // preface); over `https://` it bypasses ALPN-readback and starts the h2 driver
49        // directly after the TLS handshake — useful for TLS connectors that don't expose
50        // `negotiated_alpn` (e.g. native-tls today). Either way, there's no fallback path:
51        // a server that doesn't actually speak h2 surfaces as a plain IO error.
52        if self.http_version == Version::Http2 {
53            return self.exec_h2_prior_knowledge().await;
54        }
55
56        self.exec_h1_or_promote_h2().await
57    }
58
59    pub(crate) fn body_len(&self) -> Option<u64> {
60        if let Some(ref body) = self.request_body {
61            body.len()
62        } else {
63            Some(0)
64        }
65    }
66
67    pub(crate) fn finalize_headers(&mut self) -> Result<()> {
68        match self.http_version {
69            Version::Http1_0 | Version::Http1_1 => self.finalize_headers_h1(),
70            Version::Http2 => self.finalize_headers_h2(),
71            Version::Http3 if self.h3_client_state.is_some() => self.finalize_headers_h3(),
72            other => Err(Error::UnsupportedVersion(other)),
73        }
74    }
75}
76
77impl Drop for Conn {
78    fn drop(&mut self) {
79        log::trace!("dropping client conn");
80        let Some(mut transport) = self.transport.take() else {
81            log::trace!("no transport, nothing to do");
82
83            return;
84        };
85
86        if !self.is_keep_alive() {
87            log::trace!("not keep alive, closing");
88
89            self.config
90                .runtime()
91                .clone()
92                .spawn(async move { transport.close().await });
93
94            return;
95        }
96
97        let Ok(Some(peer_addr)) = transport.peer_addr() else {
98            return;
99        };
100        let Some(pool) = self.pool.take() else { return };
101
102        let origin = self.url.origin();
103
104        if self.response_body_state == ReceivedBodyState::End {
105            log::trace!(
106                "response body has been read to completion, checking transport back into pool for \
107                 {}",
108                &peer_addr
109            );
110            pool.insert(origin, PoolEntry::new(transport, None));
111        } else {
112            let content_length = self.response_content_length();
113            let buffer = mem::take(&mut self.buffer);
114            let response_body_state = self.response_body_state;
115            let encoding = encoding(&self.response_headers);
116            self.config.runtime().spawn(async move {
117                let mut response_body = ReceivedBody::new(
118                    content_length,
119                    buffer,
120                    transport,
121                    response_body_state,
122                    None,
123                    encoding,
124                );
125
126                match io::copy(&mut response_body, io::sink()).await {
127                    Ok(bytes) => {
128                        let transport = response_body.take_transport().unwrap();
129                        log::trace!(
130                            "read {} bytes in order to recycle conn for {}",
131                            bytes,
132                            &peer_addr
133                        );
134                        pool.insert(origin, PoolEntry::new(transport, None));
135                    }
136
137                    Err(ioerror) => log::error!("unable to recycle conn due to {}", ioerror),
138                };
139            });
140        }
141    }
142}
143
144impl From<Conn> for Body {
145    fn from(conn: Conn) -> Body {
146        let received_body: ReceivedBody<'static, _> = conn.into();
147        received_body.into()
148    }
149}
150
151impl From<Conn> for ReceivedBody<'static, Box<dyn Transport>> {
152    fn from(mut conn: Conn) -> Self {
153        let _ = conn.finalize_headers();
154        let runtime = conn.config.runtime();
155        let origin = conn.url.origin();
156
157        let on_completion = if conn.is_keep_alive()
158            && let Some(pool) = conn.pool.take()
159        {
160            Box::new(move |transport: Box<dyn Transport>| {
161                log::trace!("body transferred, returning to pool");
162                pool.insert(origin.clone(), PoolEntry::new(transport, None));
163            }) as Box<dyn FnOnce(Box<dyn Transport>) + Send + Sync + 'static>
164        } else {
165            Box::new(move |mut transport: Box<dyn Transport>| {
166                runtime.spawn(async move { transport.close().await });
167            }) as Box<dyn FnOnce(Box<dyn Transport>) + Send + Sync + 'static>
168        };
169
170        ReceivedBody::new(
171            conn.response_content_length(),
172            mem::take(&mut conn.buffer),
173            conn.transport.take().unwrap(),
174            conn.response_body_state,
175            Some(on_completion),
176            conn.response_encoding(),
177        )
178    }
179}
180
181impl From<Conn> for Upgrade<Box<dyn Transport>> {
182    fn from(mut conn: Conn) -> Self {
183        Upgrade::new(
184            mem::take(&mut conn.request_headers),
185            conn.url.path().to_string(),
186            conn.method,
187            conn.transport.take().unwrap(),
188            mem::take(&mut conn.buffer),
189            conn.http_version(),
190        )
191    }
192}
193
194impl IntoFuture for Conn {
195    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;
196    type Output = Result<Conn>;
197
198    fn into_future(mut self) -> Self::IntoFuture {
199        Box::pin(async move { (&mut self).await.map(|()| self) })
200    }
201}
202
203impl<'conn> IntoFuture for &'conn mut Conn {
204    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'conn>>;
205    type Output = Result<()>;
206
207    fn into_future(self) -> Self::IntoFuture {
208        Box::pin(async move {
209            if let Some(duration) = self.timeout {
210                self.config
211                    .runtime()
212                    .timeout(duration, self.exec())
213                    .await
214                    .unwrap_or(Err(Error::TimedOut("Conn", duration)))?;
215            } else {
216                self.exec().await?;
217            }
218            Ok(())
219        })
220    }
221}
222
223impl Debug for Conn {
224    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
225        f.debug_struct("Conn")
226            .field("authority", &self.authority)
227            .field("buffer", &String::from_utf8_lossy(&self.buffer))
228            .field("config", &self.config)
229            .field("h3_client_state", &self.h3_client_state)
230            .field("protocol_session", &self.protocol_session)
231            .field("http_version", &self.http_version)
232            .field("method", &self.method)
233            .field("path", &self.path)
234            .field("pool", &self.pool)
235            .field("request_body", &self.request_body)
236            .field("request_headers", &self.request_headers)
237            .field("request_target", &self.request_target)
238            .field("request_trailers", &self.request_trailers)
239            .field("response_body_state", &self.response_body_state)
240            .field("response_headers", &self.response_headers)
241            .field("response_trailers", &self.response_trailers)
242            .field("scheme", &self.scheme)
243            .field("state", &self.state)
244            .field("status", &self.status)
245            .field("url", &self.url)
246            .finish()
247    }
248}
249
250impl AsRef<TypeSet> for Conn {
251    fn as_ref(&self) -> &TypeSet {
252        &self.state
253    }
254}
255
256impl AsMut<TypeSet> for Conn {
257    fn as_mut(&mut self) -> &mut TypeSet {
258        &mut self.state
259    }
260}