1use crate::{Conn, Headers, Method, StateSet, Stopper};
2use futures_lite::{AsyncRead, AsyncWrite};
3use std::{
4 fmt::{self, Debug, Formatter},
5 io,
6 pin::Pin,
7 str,
8 task::{Context, Poll},
9};
10use trillium_macros::AsyncWrite;
11
/// The parts of an http [`Conn`] that are retained when it is converted into an
/// `Upgrade` via `From<Conn<Transport>>`.
///
/// NOTE(review): the name suggests this is the hand-off point for protocol
/// upgrades (e.g. websockets) — confirm against callers.
///
/// Reading from an `Upgrade` yields any bytes held in
/// [`buffer`](Upgrade::buffer) before reading from the transport. Writes are
/// presumably forwarded to the `#[async_write]`-marked transport field by the
/// `AsyncWrite` derive.
#[derive(AsyncWrite)]
pub struct Upgrade<Transport> {
    /// The http request headers of the request that was upgraded.
    pub request_headers: Headers,
    /// The request path as stored, which may include a `?querystring`; use
    /// [`Upgrade::path`] and [`Upgrade::querystring`] to get the two parts.
    pub path: String,
    /// The http method of the request that was upgraded.
    pub method: Method,
    /// The state set associated with this upgrade (the conn's state when built
    /// via `From<Conn<Transport>>`).
    pub state: StateSet,
    /// The underlying transport. Marked `#[async_write]` for the `AsyncWrite`
    /// derive, which presumably delegates writes to this field.
    #[async_write]
    pub transport: Transport,
    /// Bytes that were buffered on the conn but not yet consumed. The
    /// `AsyncRead` impl serves these before reading from `transport`. `None`
    /// when there is nothing pending.
    pub buffer: Option<Vec<u8>>,
    /// The stopper carried over from the conn.
    /// NOTE(review): presumably used to observe server shutdown — confirm.
    pub stopper: Stopper,
}
45
46impl<Transport> Upgrade<Transport> {
47 #[deprecated = "directly access the request_headers field"]
49 pub fn headers(&self) -> &Headers {
50 &self.request_headers
51 }
52
53 pub fn path(&self) -> &str {
55 match self.path.split_once('?') {
56 Some((path, _)) => path,
57 None => &self.path,
58 }
59 }
60
61 pub fn querystring(&self) -> &str {
63 self.path
64 .split_once('?')
65 .map(|(_, query)| query)
66 .unwrap_or_default()
67 }
68
69 pub fn method(&self) -> &Method {
71 &self.method
72 }
73
74 pub fn state(&self) -> &StateSet {
77 &self.state
78 }
79
80 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
84 self,
85 f: impl Fn(Transport) -> T,
86 ) -> Upgrade<T> {
87 Upgrade {
88 transport: f(self.transport),
89 path: self.path,
90 method: self.method,
91 state: self.state,
92 buffer: self.buffer,
93 request_headers: self.request_headers,
94 stopper: self.stopper,
95 }
96 }
97}
98
99impl<Transport> Debug for Upgrade<Transport> {
100 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
101 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
102 .field("request_headers", &self.request_headers)
103 .field("path", &self.path)
104 .field("method", &self.method)
105 .field(
106 "buffer",
107 &self.buffer.as_deref().map(String::from_utf8_lossy),
108 )
109 .field("stopper", &self.stopper)
110 .field("state", &self.state)
111 .field("transport", &"..")
112 .finish()
113 }
114}
115
116impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
117 fn from(conn: Conn<Transport>) -> Self {
118 let Conn {
119 request_headers,
120 path,
121 method,
122 state,
123 transport,
124 buffer,
125 stopper,
126 ..
127 } = conn;
128
129 Self {
130 request_headers,
131 path,
132 method,
133 state,
134 transport,
135 buffer: if buffer.is_empty() {
136 None
137 } else {
138 Some(buffer.into())
139 },
140 stopper,
141 }
142 }
143}
144
145impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
146 fn poll_read(
147 mut self: Pin<&mut Self>,
148 cx: &mut Context<'_>,
149 buf: &mut [u8],
150 ) -> Poll<io::Result<usize>> {
151 match self.buffer.take() {
152 Some(mut buffer) if !buffer.is_empty() => {
153 let len = buffer.len();
154 if len > buf.len() {
155 log::trace!(
156 "have {} bytes of pending data but can only use {}",
157 len,
158 buf.len()
159 );
160 let remaining = buffer.split_off(buf.len());
161 buf.copy_from_slice(&buffer[..]);
162 self.buffer = Some(remaining);
163 Poll::Ready(Ok(buf.len()))
164 } else {
165 log::trace!("have {} bytes of pending data, using all of it", len);
166 buf[..len].copy_from_slice(&buffer);
167 self.buffer = None;
168 match Pin::new(&mut self.transport).poll_read(cx, &mut buf[len..]) {
169 Poll::Ready(Ok(e)) => Poll::Ready(Ok(e + len)),
170 Poll::Pending => Poll::Ready(Ok(len)),
171 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
172 }
173 }
174 }
175
176 _ => Pin::new(&mut self.transport).poll_read(cx, buf),
177 }
178 }
179}