1use crate::{
2 Buffer, Conn, Headers, HttpContext, Method, TypeSet, Version, h3::H3Connection,
3 received_body::read_buffered,
4};
5use fieldwork::Fieldwork;
6use futures_lite::{AsyncRead, AsyncWrite};
7use std::{
8 borrow::Cow,
9 fmt::{self, Debug, Formatter},
10 io,
11 net::IpAddr,
12 pin::Pin,
13 str,
14 sync::Arc,
15 task::{self, Poll},
16};
17use trillium_macros::AsyncWrite;
18
/// An HTTP connection converted out of a [`Conn`], keeping its request metadata,
/// per-connection [`TypeSet`] state, underlying transport, and any bytes already
/// read into its buffer — presumably for use after a protocol upgrade (confirm).
///
/// Accessors are generated by [`Fieldwork`] according to the struct-level
/// `#[fieldwork(...)]` options; individual fields opt out via `#[field(...)]`.
/// `AsyncWrite` is derived and marked on the `transport` field; `AsyncRead` is
/// implemented by hand so that the buffer is consulted alongside the transport.
#[derive(AsyncWrite, Fieldwork)]
#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
pub struct Upgrade<Transport> {
    /// The request headers.
    request_headers: Headers,

    /// The raw request target, which may include a `?querystring` suffix.
    ///
    /// No generated getter: the hand-written `path` and `querystring` methods
    /// split this value on the first `?` instead.
    #[field(get = false)]
    path: Cow<'static, str>,

    /// The request method.
    #[field(copy)]
    method: Method,

    /// Per-connection state (carried over from the `Conn`; empty from `new`).
    state: TypeSet,

    /// The underlying transport; the derived `AsyncWrite` is marked on this field.
    #[async_write]
    transport: Transport,

    /// Bytes held alongside the transport. `AsyncRead` hands this buffer to
    /// `read_buffered` together with the transport — presumably so these bytes
    /// are served before new reads (confirm). Setters, `with`, and `into_field`
    /// are disabled; the getter derefs to `[u8]`.
    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
    buffer: Buffer,

    /// Shared HTTP context; `shared_state` reads from it. `deref = false`
    /// presumably keeps the getter returning the `Arc` itself (confirm).
    #[field(deref = false)]
    context: Arc<HttpContext>,

    /// The peer's IP address, if known (`None` when built with `new`).
    #[field(copy)]
    peer_ip: Option<IpAddr>,

    /// The request authority, if set.
    authority: Option<Cow<'static, str>>,

    /// The request scheme, if set.
    scheme: Option<Cow<'static, str>>,

    /// The HTTP/3 connection this upgrade belongs to, if any. Only a
    /// non-dereferencing getter is generated; it cannot be mutated or replaced.
    #[field(
        get(deref = false),
        get_mut = false,
        set = false,
        with = false,
        into_field = false,
        take = false
    )]
    h3_connection: Option<Arc<H3Connection>>,

    /// The requested protocol, if any.
    protocol: Option<Cow<'static, str>>,

    /// The request's HTTP version; generated accessors use the name `http_version`.
    #[field = "http_version"]
    version: Version,

    /// Whether the connection is marked secure (`false` when built with `new`).
    secure: bool,
}
89
90impl<Transport> Upgrade<Transport> {
91 #[doc(hidden)]
92 pub fn new(
93 request_headers: Headers,
94 path: impl Into<Cow<'static, str>>,
95 method: Method,
96 transport: Transport,
97 buffer: Buffer,
98 version: Version,
99 ) -> Self {
100 Self {
101 request_headers,
102 path: path.into(),
103 method,
104 transport,
105 buffer,
106 state: TypeSet::new(),
107 context: Arc::default(),
108 peer_ip: None,
109 authority: None,
110 scheme: None,
111 h3_connection: None,
112 protocol: None,
113 secure: false,
114 version,
115 }
116 }
117
118 pub fn take_buffer(&mut self) -> Vec<u8> {
120 std::mem::take(&mut self.buffer).into()
121 }
122
123 #[doc(hidden)]
124 pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
125 (&mut self.buffer, &mut self.transport)
126 }
127
128 pub fn shared_state(&self) -> &TypeSet {
130 self.context.shared_state()
131 }
132
133 pub fn path(&self) -> &str {
135 match self.path.split_once('?') {
136 Some((path, _)) => path,
137 None => &self.path,
138 }
139 }
140
141 pub fn querystring(&self) -> &str {
143 self.path
144 .split_once('?')
145 .map(|(_, query)| query)
146 .unwrap_or_default()
147 }
148
149 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
153 self,
154 f: impl Fn(Transport) -> T,
155 ) -> Upgrade<T> {
156 Upgrade {
157 transport: f(self.transport),
158 path: self.path,
159 method: self.method,
160 state: self.state,
161 buffer: self.buffer,
162 request_headers: self.request_headers,
163 context: self.context,
164 peer_ip: self.peer_ip,
165 authority: self.authority,
166 scheme: self.scheme,
167 h3_connection: self.h3_connection,
168 protocol: self.protocol,
169 version: self.version,
170 secure: self.secure,
171 }
172 }
173}
174
175impl<Transport> Debug for Upgrade<Transport> {
176 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
177 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
178 .field("request_headers", &self.request_headers)
179 .field("path", &self.path)
180 .field("method", &self.method)
181 .field("buffer", &self.buffer)
182 .field("context", &self.context)
183 .field("state", &self.state)
184 .field("transport", &format_args!(".."))
185 .field("peer_ip", &self.peer_ip)
186 .field("authority", &self.authority)
187 .field("scheme", &self.scheme)
188 .field("h3_connection", &self.h3_connection)
189 .field("protocol", &self.protocol)
190 .field("version", &self.version)
191 .field("secure", &self.secure)
192 .finish()
193 }
194}
195
196impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
197 fn from(conn: Conn<Transport>) -> Self {
198 let Conn {
199 request_headers,
200 path,
201 method,
202 state,
203 transport,
204 buffer,
205 context,
206 peer_ip,
207 authority,
208 scheme,
209 h3_connection,
210 protocol,
211 version,
212 secure,
213 ..
214 } = conn;
215
216 Self {
217 request_headers,
218 path,
219 method,
220 state,
221 transport,
222 buffer,
223 context,
224 peer_ip,
225 authority,
226 scheme,
227 h3_connection,
228 protocol,
229 version,
230 secure,
231 }
232 }
233}
234
235impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
236 fn poll_read(
237 mut self: Pin<&mut Self>,
238 cx: &mut task::Context<'_>,
239 buf: &mut [u8],
240 ) -> Poll<io::Result<usize>> {
241 let Self {
242 transport, buffer, ..
243 } = &mut *self;
244 read_buffered(buffer, transport, cx, buf)
245 }
246}