1use crate::{pool::PoolEntry, util::encoding, Pool};
2use encoding_rs::Encoding;
3use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt};
4use memchr::memmem::Finder;
5use size::{Base, Size};
6use std::{
7 fmt::{self, Debug, Display, Formatter},
8 future::{Future, IntoFuture},
9 io::{ErrorKind, Write},
10 ops::{Deref, DerefMut},
11 pin::Pin,
12 str::FromStr,
13 sync::Arc,
14};
15use trillium_http::{
16 transport::BoxedTransport,
17 Body, Error, HeaderName, HeaderValue, HeaderValues, Headers,
18 KnownHeaderName::{Connection, ContentLength, Expect, Host, TransferEncoding},
19 Method, ReceivedBody, ReceivedBodyState, Result, StateSet, Status, Stopper, Upgrade,
20};
21use trillium_server_common::{
22 url::{Origin, Url},
23 Connector, ObjectSafeConnector, Transport,
24};
25
/// Maximum number of header fields parsed from a single response head.
const MAX_HEADERS: usize = 128;

/// Maximum number of bytes buffered while searching for the end of a response head
/// (status line plus header fields) before giving up with `Error::HeadersTooLong`.
///
/// Real-world responses routinely carry more than 2 KiB of header fields (large
/// `Set-Cookie` or `Content-Security-Policy` values, for instance), so a 2 KiB limit
/// rejects legitimate responses; 8 KiB is in line with the head limits common in HTTP
/// servers.
const MAX_HEAD_LENGTH: usize = 8 * 1024;
28
#[cfg(feature = "json")]
/// Error type for the JSON response helper [`Conn::response_json`]: either reading the
/// response failed at the HTTP layer, or its body could not be deserialized.
#[derive(Debug, thiserror::Error)]
pub enum ClientSerdeError {
    /// Reading the response failed with a trillium HTTP error.
    #[error(transparent)]
    HttpError(#[from] Error),

    /// The response body was not valid JSON for the requested type.
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),
}
45
/// A single HTTP/1.1 client request and, once executed, its response.
///
/// A `Conn` is configured with builder-style methods and performs its network I/O when it
/// is awaited (see the `IntoFuture` impls), which is why it is `#[must_use]`.
#[must_use]
pub struct Conn {
    /// The request target; also supplies the `Host` header and the pool origin.
    pub(crate) url: Url,
    /// The request method.
    pub(crate) method: Method,
    /// Headers sent with the request; completed by `finalize_headers` before sending.
    pub(crate) request_headers: Headers,
    /// Headers parsed from the response head(s) by `parse_head`.
    pub(crate) response_headers: Headers,
    /// The connection to the server: `None` until `connect_and_send_head` stores one, and
    /// taken again when the conn is converted into a body/upgrade or dropped.
    pub(crate) transport: Option<BoxedTransport>,
    /// The response status: `None` until a response head has been parsed.
    pub(crate) status: Option<Status>,
    /// The request body, if any; taken when it is sent or deliberately not sent.
    pub(crate) request_body: Option<Body>,
    /// Connection pool to return the transport to. When `None`, `finalize_headers` sends
    /// `Connection: close` and the transport is never pooled.
    pub(crate) pool: Option<Pool<Origin, BoxedTransport>>,
    /// Bytes read from the transport that have not yet been consumed as head or body.
    pub(crate) buffer: trillium_http::Buffer,
    /// How far the response body has been read; `ReceivedBodyState::End` once complete.
    pub(crate) response_body_state: ReceivedBodyState,
    /// Connector used to open new connections and to spawn background body draining.
    pub(crate) config: Arc<dyn ObjectSafeConnector>,
    /// Set once `finalize_headers` has run, so repeated calls are no-ops.
    pub(crate) headers_finalized: bool,
}
65
66pub const USER_AGENT: &str = concat!("trillium-client/", env!("CARGO_PKG_VERSION"));
68
69impl Debug for Conn {
70 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
71 f.debug_struct("Conn")
72 .field("url", &self.url)
73 .field("method", &self.method)
74 .field("request_headers", &self.request_headers)
75 .field("response_headers", &self.response_headers)
76 .field("status", &self.status)
77 .field("request_body", &self.request_body)
78 .field("pool", &self.pool)
79 .field("buffer", &String::from_utf8_lossy(&self.buffer))
80 .field("response_body_state", &self.response_body_state)
81 .field("config", &self.config)
82 .finish()
83 }
84}
85
86impl Conn {
87 pub fn request_headers(&self) -> &Headers {
89 &self.request_headers
90 }
91
92 pub fn with_request_header(
121 mut self,
122 name: impl Into<HeaderName<'static>>,
123 value: impl Into<HeaderValues>,
124 ) -> Self {
125 self.request_headers.insert(name, value);
126 self
127 }
128
129 #[deprecated = "use Conn::with_request_header"]
130 pub fn with_header(
132 self,
133 name: impl Into<HeaderName<'static>>,
134 value: impl Into<HeaderValues>,
135 ) -> Self {
136 self.with_request_header(name, value)
137 }
138
139 pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
168 where
169 I: IntoIterator<Item = (HN, HV)> + Send,
170 HN: Into<HeaderName<'static>>,
171 HV: Into<HeaderValues>,
172 {
173 self.request_headers.extend(headers);
174 self
175 }
176
177 #[deprecated = "use Conn::with_request_headers"]
179 pub fn with_headers<HN, HV, I>(self, headers: I) -> Self
180 where
181 I: IntoIterator<Item = (HN, HV)> + Send,
182 HN: Into<HeaderName<'static>>,
183 HV: Into<HeaderValues>,
184 {
185 self.with_request_headers(headers)
186 }
187
188 pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
190 self.request_headers.remove(name);
191 self
192 }
193
194 #[deprecated = "use Conn::without_request_header"]
196 pub fn without_header(self, name: impl Into<HeaderName<'static>>) -> Self {
197 self.without_request_header(name)
198 }
199
200 pub fn response_headers(&self) -> &Headers {
222 &self.response_headers
223 }
224
225 pub fn request_headers_mut(&mut self) -> &mut Headers {
259 &mut self.request_headers
260 }
261
262 pub fn response_headers_mut(&mut self) -> &mut Headers {
264 &mut self.response_headers
265 }
266
267 pub fn set_request_body(&mut self, body: impl Into<Body>) {
296 self.request_body = Some(body.into());
297 }
298
299 pub fn with_body(mut self, body: impl Into<Body>) -> Self {
328 self.set_request_body(body);
329 self
330 }
331
332 #[cfg(feature = "json")]
336 pub fn with_json_body(self, body: &impl serde::Serialize) -> serde_json::Result<Self> {
337 use trillium_http::KnownHeaderName;
338
339 Ok(self
340 .with_body(serde_json::to_string(body)?)
341 .with_request_header(KnownHeaderName::ContentType, "application/json"))
342 }
343
344 pub(crate) fn response_encoding(&self) -> &'static Encoding {
345 encoding(&self.response_headers)
346 }
347
348 pub fn url(&self) -> &Url {
362 &self.url
363 }
364
365 pub fn method(&self) -> Method {
382 self.method
383 }
384
385 #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)]
413 pub fn response_body(&mut self) -> ReceivedBody<'_, BoxedTransport> {
414 ReceivedBody::new(
415 self.response_content_length(),
416 &mut self.buffer,
417 self.transport.as_mut().unwrap(),
418 &mut self.response_body_state,
419 None,
420 encoding(&self.response_headers),
421 )
422 }
423
424 #[cfg(feature = "json")]
428 pub async fn response_json<T>(&mut self) -> std::result::Result<T, ClientSerdeError>
429 where
430 T: serde::de::DeserializeOwned,
431 {
432 let body = self.response_body().read_string().await?;
433 Ok(serde_json::from_str(&body)?)
434 }
435
436 pub(crate) fn response_content_length(&self) -> Option<u64> {
437 if self.status == Some(Status::NoContent)
438 || self.status == Some(Status::NotModified)
439 || self.method == Method::Head
440 {
441 Some(0)
442 } else {
443 self.response_headers
444 .get_str(ContentLength)
445 .and_then(|c| c.parse().ok())
446 }
447 }
448
449 pub fn status(&self) -> Option<Status> {
471 self.status
472 }
473
474 pub fn success(self) -> std::result::Result<Self, UnexpectedStatusError> {
497 match self.status() {
498 Some(status) if status.is_success() => Ok(self),
499 _ => Err(self.into()),
500 }
501 }
502
503 pub async fn recycle(mut self) {
510 if self.is_keep_alive() && self.transport.is_some() && self.pool.is_some() {
511 self.finish_reading_body().await;
512 }
513 }
514
515 pub fn peer_addr(&self) -> Option<std::net::SocketAddr> {
517 self.transport
518 .as_ref()
519 .and_then(|t| t.peer_addr().ok().flatten())
520 }
521
522 fn finalize_headers(&mut self) -> Result<()> {
525 if self.headers_finalized {
526 return Ok(());
527 }
528
529 let host = self.url.host_str().ok_or(Error::UnexpectedUriFormat)?;
530
531 self.request_headers.try_insert_with(Host, || {
532 self.url
533 .port()
534 .map_or_else(|| host.to_string(), |port| format!("{host}:{port}"))
535 });
536
537 if self.pool.is_none() {
538 self.request_headers.try_insert(Connection, "close");
539 }
540
541 match self.body_len() {
542 Some(0) => {}
543 Some(len) => {
544 self.request_headers.insert(Expect, "100-continue");
545 self.request_headers.insert(ContentLength, len.to_string());
546 }
547 None => {
548 self.request_headers.insert(Expect, "100-continue");
549 self.request_headers.insert(TransferEncoding, "chunked");
550 }
551 }
552
553 self.headers_finalized = true;
554 Ok(())
555 }
556
557 fn body_len(&self) -> Option<u64> {
558 if let Some(ref body) = self.request_body {
559 body.len()
560 } else {
561 Some(0)
562 }
563 }
564
565 async fn find_pool_candidate(&self, head: &[u8]) -> Result<Option<BoxedTransport>> {
566 let mut byte = [0];
567 if let Some(pool) = &self.pool {
568 for mut candidate in pool.candidates(&self.url.origin()) {
569 if poll_once(candidate.read(&mut byte)).await.is_none()
570 && candidate.write_all(head).await.is_ok()
571 {
572 return Ok(Some(candidate));
573 }
574 }
575 }
576 Ok(None)
577 }
578
579 async fn connect_and_send_head(&mut self) -> Result<()> {
580 if self.transport.is_some() {
581 return Err(Error::Io(std::io::Error::new(
582 ErrorKind::AlreadyExists,
583 "conn already connected",
584 )));
585 }
586
587 let head = self.build_head().await?;
588
589 let transport = match self.find_pool_candidate(&head).await? {
590 Some(transport) => {
591 log::debug!("reusing connection to {:?}", transport.peer_addr()?);
592 transport
593 }
594
595 None => {
596 let mut transport = Connector::connect(&self.config, &self.url).await?;
597 log::debug!("opened new connection to {:?}", transport.peer_addr()?);
598 transport.write_all(&head).await?;
599 transport
600 }
601 };
602
603 self.transport = Some(transport);
604 Ok(())
605 }
606
607 async fn build_head(&mut self) -> Result<Vec<u8>> {
608 let mut buf = Vec::with_capacity(128);
609 let url = &self.url;
610 let method = self.method;
611 write!(buf, "{method} ")?;
612
613 if method == Method::Connect {
614 let host = url.host_str().ok_or(Error::UnexpectedUriFormat)?;
615
616 let port = url
617 .port_or_known_default()
618 .ok_or(Error::UnexpectedUriFormat)?;
619
620 write!(buf, "{host}:{port}")?;
621 } else {
622 write!(buf, "{}", url.path())?;
623 if let Some(query) = url.query() {
624 write!(buf, "?{query}")?;
625 }
626 }
627
628 write!(buf, " HTTP/1.1\r\n")?;
629
630 for (name, values) in &self.request_headers {
631 if !name.is_valid() {
632 return Err(Error::MalformedHeader(name.to_string().into()));
633 }
634
635 for value in values {
636 if !value.is_valid() {
637 return Err(Error::MalformedHeader(
638 format!("value for {name}: {value:?}").into(),
639 ));
640 }
641 write!(buf, "{name}: ")?;
642 buf.extend_from_slice(value.as_ref());
643 write!(buf, "\r\n")?;
644 }
645 }
646
647 write!(buf, "\r\n")?;
648 log::trace!(
649 "{}",
650 std::str::from_utf8(&buf).unwrap().replace("\r\n", "\r\n> ")
651 );
652
653 Ok(buf)
654 }
655
656 fn transport(&mut self) -> &mut BoxedTransport {
657 self.transport.as_mut().unwrap()
658 }
659
    /// Reads from the transport until `self.buffer` holds a complete response head, that
    /// is, up to and including the first `\r\n\r\n`.
    ///
    /// Bytes already in the buffer are searched first; for example, bytes left behind by a
    /// previous call to `parse_head`. The return value is the offset just past the
    /// terminating `\r\n\r\n`. Any bytes read beyond that offset stay in the buffer.
    ///
    /// # Errors
    ///
    /// * [`Error::Closed`] if there is no transport, or if the peer closed the connection
    ///   before sending any bytes.
    /// * [`Error::PartialHead`] if the peer closed the connection partway through a head.
    /// * [`Error::HeadersTooLong`] once `MAX_HEAD_LENGTH` bytes have been read without
    ///   finding the end of the head.
    /// * Any I/O error from reading the transport.
    async fn read_head(&mut self) -> Result<usize> {
        let Self {
            buffer,
            transport: Some(transport),
            ..
        } = self
        else {
            return Err(Error::Closed);
        };

        // `len` counts the bytes of `buffer` that hold real data.
        let mut len = buffer.len();
        // Where the next search begins. Restarting 3 bytes before the end of the data
        // already searched still finds a terminator that is split across two reads.
        let mut search_start = 0;
        let finder = Finder::new(b"\r\n\r\n");

        if len > 0 {
            if let Some(index) = finder.find(buffer) {
                return Ok(index + 4);
            }
            search_start = len.saturating_sub(3);
        }

        loop {
            // `expand` makes room past `len` for `read` to fill. The unfilled tail is cut
            // off by `truncate` once the head is found.
            // NOTE(review): the error returns below do not truncate that tail.
            buffer.expand();
            let bytes = transport.read(&mut buffer[len..]).await?;
            len += bytes;

            let search = finder.find(&buffer[search_start..len]);

            if let Some(index) = search {
                buffer.truncate(len);
                return Ok(search_start + index + 4);
            }

            search_start = len.saturating_sub(3);

            // a zero-byte read means the peer will send nothing more
            if bytes == 0 {
                if len == 0 {
                    return Err(Error::Closed);
                } else {
                    return Err(Error::PartialHead);
                }
            }

            if len >= MAX_HEAD_LENGTH {
                return Err(Error::HeadersTooLong);
            }
        }
    }
708
709 async fn parse_head(&mut self) -> Result<()> {
710 let head_offset = self.read_head().await?;
711 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
712 let mut httparse_res = httparse::Response::new(&mut headers);
713 let parse_result = httparse_res.parse(&self.buffer[..head_offset])?;
714
715 match parse_result {
716 httparse::Status::Complete(n) if n == head_offset => {}
717 _ => return Err(Error::PartialHead),
718 }
719
720 self.status = httparse_res.code.map(|code| code.try_into().unwrap());
721
722 self.response_headers.reserve(httparse_res.headers.len());
723 for header in httparse_res.headers {
724 let header_name = HeaderName::from_str(header.name)?;
725 let header_value = HeaderValue::from(header.value.to_owned());
726 self.response_headers.append(header_name, header_value);
727 }
728
729 self.buffer.ignore_front(head_offset);
730
731 self.validate_response_headers()?;
732 Ok(())
733 }
734
735 async fn send_body_and_parse_head(&mut self) -> Result<()> {
736 if self
737 .request_headers
738 .eq_ignore_ascii_case(Expect, "100-continue")
739 {
740 log::trace!("Expecting 100-continue");
741 self.parse_head().await?;
742 if self.status == Some(Status::Continue) {
743 self.status = None;
744 log::trace!("Received 100-continue, sending request body");
745 } else {
746 self.request_body.take();
747 log::trace!(
748 "Received a status code other than 100-continue, not sending request body"
749 );
750 return Ok(());
751 }
752 }
753
754 self.send_body().await?;
755 loop {
756 self.parse_head().await?;
757 if self.status == Some(Status::Continue) {
758 self.status = None;
759 } else {
760 break;
761 }
762 }
763
764 Ok(())
765 }
766
767 async fn send_body(&mut self) -> Result<()> {
768 if let Some(mut body) = self.request_body.take() {
769 io::copy(&mut body, self.transport()).await?;
770 }
771 Ok(())
772 }
773
774 fn validate_response_headers(&self) -> Result<()> {
775 let content_length = self.response_headers.has_header(ContentLength);
776
777 let transfer_encoding_chunked = self
778 .response_headers
779 .eq_ignore_ascii_case(TransferEncoding, "chunked");
780
781 if content_length && transfer_encoding_chunked {
782 Err(Error::UnexpectedHeader("content-length"))
783 } else {
784 Ok(())
785 }
786 }
787
788 fn is_keep_alive(&self) -> bool {
789 self.response_headers
790 .eq_ignore_ascii_case(Connection, "keep-alive")
791 }
792
793 async fn finish_reading_body(&mut self) {
794 if self.response_body_state != ReceivedBodyState::End {
795 let body = self.response_body();
796 match body.drain().await {
797 Ok(drain) => log::debug!(
798 "drained {}",
799 Size::from_bytes(drain).format().with_base(Base::Base10)
800 ),
801 Err(e) => log::warn!("failed to drain body, {:?}", e),
802 }
803 }
804 }
805
806 async fn exec(&mut self) -> Result<()> {
807 self.finalize_headers()?;
808 self.connect_and_send_head().await?;
809 self.send_body_and_parse_head().await?;
810 Ok(())
811 }
812}
813
impl Drop for Conn {
    /// Returns a reusable connection's transport to the pool.
    ///
    /// Nothing is pooled unless all of the following hold:
    /// * the response sent `Connection: keep-alive`;
    /// * a transport and a pool are both still present;
    /// * the transport reports a peer address.
    ///
    /// If the response body has been fully read, the transport is pooled immediately.
    /// Otherwise a task is spawned on the connector that reads the rest of the body into
    /// a sink and then pools the transport, or logs an error if that read fails.
    fn drop(&mut self) {
        if !self.is_keep_alive() {
            return;
        }

        let Some(transport) = self.transport.take() else {
            return;
        };
        // NOTE(review): peer_addr is only used for logging below, yet a transport that
        // reports no peer address (or an error) is never pooled. Confirm this is intended.
        let Ok(Some(peer_addr)) = transport.peer_addr() else {
            return;
        };
        let Some(pool) = self.pool.take() else { return };

        let origin = self.url.origin();

        if self.response_body_state == ReceivedBodyState::End {
            log::trace!("response body has been read to completion, checking transport back into pool for {}", &peer_addr);
            pool.insert(origin, PoolEntry::new(transport, None));
        } else {
            // The rest of the body has to be consumed before the transport can carry
            // another exchange. Move everything the body reader needs out of `self` so
            // it can run in a spawned task after this conn is gone.
            let content_length = self.response_content_length();
            let buffer = std::mem::take(&mut self.buffer);
            let response_body_state = self.response_body_state;
            let encoding = encoding(&self.response_headers);
            Connector::spawn(&self.config, async move {
                let mut response_body = ReceivedBody::new(
                    content_length,
                    buffer,
                    transport,
                    response_body_state,
                    None,
                    encoding,
                );

                match io::copy(&mut response_body, io::sink()).await {
                    Ok(bytes) => {
                        let transport = response_body.take_transport().unwrap();
                        log::trace!(
                            "read {} bytes in order to recycle conn for {}",
                            bytes,
                            &peer_addr
                        );
                        pool.insert(origin, PoolEntry::new(transport, None));
                    }

                    Err(ioerror) => log::error!("unable to recycle conn due to {}", ioerror),
                };
            });
        }
    }
}
865
866impl From<Conn> for Body {
867 fn from(conn: Conn) -> Body {
868 let received_body: ReceivedBody<'static, _> = conn.into();
869 received_body.into()
870 }
871}
872
873impl From<Conn> for ReceivedBody<'static, BoxedTransport> {
874 fn from(mut conn: Conn) -> Self {
875 let _ = conn.finalize_headers();
876 let origin = conn.url.origin();
877
878 let on_completion =
879 conn.pool
880 .take()
881 .map(|pool| -> Box<dyn Fn(BoxedTransport) + Send + Sync> {
882 Box::new(move |transport| {
883 pool.insert(origin.clone(), PoolEntry::new(transport, None));
884 })
885 });
886
887 ReceivedBody::new(
888 conn.response_content_length(),
889 std::mem::take(&mut conn.buffer),
890 conn.transport.take().unwrap(),
891 conn.response_body_state,
892 on_completion,
893 conn.response_encoding(),
894 )
895 }
896}
897
898impl From<Conn> for Upgrade<BoxedTransport> {
899 fn from(mut conn: Conn) -> Self {
900 Upgrade {
901 request_headers: std::mem::take(&mut conn.request_headers),
902 path: conn.url.path().to_string(),
903 method: conn.method,
904 state: StateSet::new(),
905 transport: conn.transport.take().unwrap(),
906 buffer: Some(std::mem::take(&mut conn.buffer).into()),
907 stopper: Stopper::new(),
908 }
909 }
910}
911
912impl IntoFuture for Conn {
913 type Output = Result<Conn>;
914
915 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;
916
917 fn into_future(mut self) -> Self::IntoFuture {
918 Box::pin(async move {
919 self.exec().await?;
920 Ok(self)
921 })
922 }
923}
924
925impl<'conn> IntoFuture for &'conn mut Conn {
926 type Output = Result<()>;
927
928 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'conn>>;
929
930 fn into_future(self) -> Self::IntoFuture {
931 Box::pin(async move {
932 self.exec().await?;
933 Ok(())
934 })
935 }
936}
937
/// Error returned by [`Conn::success`] when the response status is not a success (2xx)
/// status, or when no status is set.
///
/// It owns the boxed [`Conn`], so the caller can still inspect the response: it derefs to
/// `Conn` and converts back into one.
#[derive(Debug)]
pub struct UnexpectedStatusError(Box<Conn>);
944impl From<Conn> for UnexpectedStatusError {
945 fn from(value: Conn) -> Self {
946 Self(Box::new(value))
947 }
948}
949
950impl From<UnexpectedStatusError> for Conn {
951 fn from(value: UnexpectedStatusError) -> Self {
952 *value.0
953 }
954}
955
956impl Deref for UnexpectedStatusError {
957 type Target = Conn;
958
959 fn deref(&self) -> &Self::Target {
960 &self.0
961 }
962}
963impl DerefMut for UnexpectedStatusError {
964 fn deref_mut(&mut self) -> &mut Self::Target {
965 &mut self.0
966 }
967}
968
969impl std::error::Error for UnexpectedStatusError {}
970impl Display for UnexpectedStatusError {
971 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
972 match self.status() {
973 Some(status) => f.write_fmt(format_args!(
974 "expected a success (2xx) status code, but got {status}"
975 )),
976 None => f.write_str("expected a status code to be set, but none was"),
977 }
978 }
979}