1use async_trait::async_trait;
2use bytes::{BufMut, BytesMut};
3use hyper::body::Bytes;
4use hyper::header::{HeaderName, HeaderValue};
5use hyper::{HeaderMap, Response, StatusCode};
6use prost::{DecodeError, Message};
7use std::collections::HashMap;
8use std::convert::TryInto;
9use std::str;
10use thruster::context::hyper_request::HyperRequest;
11use thruster::middleware::query_params::HasQueryParams;
12use thruster::Context;
13use tokio_stream::StreamExt;
14
15use crate::body::ProtoBody;
16use crate::context::ProtoContext as Ctx;
17
18const DEFAULT_HEADER_CAPACITY: usize = 4;
20
21pub fn generate_context(request: HyperRequest, _state: &(), _path: &str) -> Ctx<()> {
22 Ctx::new(request, ())
23}
24
/// Value of the `SameSite` attribute written into a `Set-Cookie` header.
// `dead_code` is allowed per variant because a public variant may never be
// constructed inside this crate itself (e.g. `Lax`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SameSite {
    #[allow(dead_code)]
    Strict,
    #[allow(dead_code)]
    Lax,
}

/// Attributes applied to a cookie added through `ProtoContext::cookie`.
#[derive(Debug, Clone)]
pub struct CookieOptions {
    /// `Domain` attribute; omitted when empty.
    pub domain: String,
    /// `Path` attribute; always emitted.
    pub path: String,
    /// `Expires` attribute, written as the raw number; omitted when 0.
    pub expires: u64,
    /// Emits the `HttpOnly` flag when true.
    pub http_only: bool,
    /// `Max-Age` attribute (seconds, per RFC 6265); omitted when 0.
    pub max_age: u64,
    /// Emits the `Secure` flag when true.
    pub secure: bool,
    /// NOTE(review): not read by the cookie-formatting code in this file.
    pub signed: bool,
    /// `SameSite` attribute; always emitted.
    pub same_site: SameSite,
}

impl Default for CookieOptions {
    /// Path `/` with `SameSite=Strict`; every other attribute is disabled or empty.
    fn default() -> CookieOptions {
        CookieOptions {
            domain: String::new(),
            path: String::from("/"),
            expires: 0,
            http_only: false,
            max_age: 0,
            secure: false,
            signed: false,
            same_site: SameSite::Strict,
        }
    }
}
57
58#[derive(Default)]
59pub struct ProtoContext<T> {
60 pub body: Option<ProtoBody>,
61 pub query_params: Option<HashMap<String, String>>,
62 pub status: u16,
63 pub hyper_request: Option<HyperRequest>,
64 pub extra: T,
65 http_version: hyper::Version,
66 headers: HeaderMap,
67}
68
69impl<T> ProtoContext<T> {
70 pub fn new(req: HyperRequest, extra: T) -> ProtoContext<T> {
71 let mut ctx = ProtoContext {
72 body: None,
73 query_params: None,
74 headers: HeaderMap::with_capacity(DEFAULT_HEADER_CAPACITY),
75 status: 200,
76 hyper_request: Some(req),
77 extra,
78 http_version: hyper::Version::HTTP_11,
79 };
80
81 ctx.set("Server", "Thruster");
82
83 ctx
84 }
85
86 #[allow(dead_code)]
90 pub fn status(&mut self, code: u32) {
91 self.status = code.try_into().unwrap();
92 }
93
94 #[allow(dead_code)]
102 pub fn content_type(&mut self, c_type: &str) {
103 self.set("Content-Type", c_type);
104 }
105
106 #[allow(dead_code)]
116 pub fn redirect(&mut self, destination: &str) {
117 self.status(302);
118
119 self.set("Location", destination);
120 }
121
122 #[allow(dead_code)]
126 pub fn cookie(&mut self, name: &str, value: &str, options: &CookieOptions) {
127 let cookie_value = match self.headers.get("Set-Cookie") {
128 Some(val) => format!(
129 "{}, {}",
130 val.to_str().unwrap_or_else(|_| ""),
131 self.cookify_options(name, value, &options)
132 ),
133 None => self.cookify_options(name, value, &options),
134 };
135
136 self.set("Set-Cookie", &cookie_value);
137 }
138
139 #[allow(dead_code)]
140 fn cookify_options(&self, name: &str, value: &str, options: &CookieOptions) -> String {
141 let mut pieces = vec![format!("Path={}", options.path)];
142
143 if options.expires > 0 {
144 pieces.push(format!("Expires={}", options.expires));
145 }
146
147 if options.max_age > 0 {
148 pieces.push(format!("Max-Age={}", options.max_age));
149 }
150
151 if !options.domain.is_empty() {
152 pieces.push(format!("Domain={}", options.domain));
153 }
154
155 if options.secure {
156 pieces.push("Secure".to_owned());
157 }
158
159 if options.http_only {
160 pieces.push("HttpOnly".to_owned());
161 }
162
163 match options.same_site {
164 SameSite::Strict => pieces.push("SameSite=Strict".to_owned()),
165 SameSite::Lax => pieces.push("SameSite=Lax".to_owned()),
166 };
167
168 format!("{}={}; {}", name, value, pieces.join(", "))
169 }
170
171 #[allow(dead_code)]
172 pub fn set_http2(&mut self) {
173 self.http_version = hyper::Version::HTTP_2;
174 }
175
176 #[allow(dead_code)]
177 pub fn set_http11(&mut self) {
178 self.http_version = hyper::Version::HTTP_11;
179 }
180
181 #[allow(dead_code)]
182 pub fn set_http10(&mut self) {
183 self.http_version = hyper::Version::HTTP_10;
184 }
185
186 pub fn set_proto_status(&mut self, status: u16) {
187 self.headers
188 .insert("grpc-status", format!("{}", status).parse().unwrap());
189 }
190}
191
192impl<T> Context for ProtoContext<T> {
193 type Response = Response<ProtoBody>;
194
195 fn get_response(mut self) -> Self::Response {
196 let mut body = self.body.take().unwrap();
197 body.set_headers(self.headers.clone());
198 let mut response = Response::new(body);
199
200 *response.status_mut() = StatusCode::from_u16(self.status).unwrap();
201 *response.headers_mut() = self.headers;
202 *response.version_mut() = self.http_version;
203
204 response
205 }
206
207 fn set_body(&mut self, body: Vec<u8>) {
208 self.body.replace(ProtoBody::from_bytes(Bytes::from(body)));
209 }
210
211 fn set_body_bytes(&mut self, bytes: Bytes) {
212 self.body.replace(ProtoBody::from_bytes(bytes));
213 }
214
215 fn route(&self) -> &str {
216 let uri = self.hyper_request.as_ref().unwrap().request.uri();
217
218 match uri.path_and_query() {
219 Some(val) => val.as_str(),
220 None => uri.path(),
221 }
222 }
223
224 fn set(&mut self, key: &str, value: &str) {
225 self.headers.insert(
226 HeaderName::from_bytes(key.as_bytes()).unwrap(),
227 HeaderValue::from_str(value).unwrap(),
228 );
229 }
230
231 fn remove(&mut self, key: &str) {
232 self.headers.remove(key);
233 }
234}
235
236impl<T> HasQueryParams for ProtoContext<T> {
237 fn set_query_params(&mut self, query_params: HashMap<String, String>) {
238 self.query_params = Some(query_params);
239 }
240}
241
242impl<T> Clone for ProtoContext<T> {
243 fn clone(&self) -> Self {
244 panic!("Do not use, just for internals.");
245 }
246}
247
248#[async_trait]
249pub trait ProtoContextExt<T> {
250 async fn proto<M: Message + std::default::Default>(&mut self, message: M);
251 async fn get_proto<M: Message + std::default::Default>(&mut self) -> Result<M, DecodeError>;
252}
253
254#[async_trait]
255impl<T: Send> ProtoContextExt<T> for ProtoContext<T> {
256 async fn proto<M: Message + std::default::Default>(&mut self, message: M) {
257 self.set("content-type", "application/grpc");
258 self.set("grpc-status", "0");
259 self.set("trailers", "grpc-status");
260 self.set_http2();
261
262 let mut buf = BytesMut::new();
263 buf.reserve(5);
264 buf.put(&b"00000"[..]);
265
266 let _ = message.encode(&mut buf);
267
268 let len = buf.len() - 5;
269 assert!(len <= std::u32::MAX as usize);
270 {
271 let mut buf = &mut buf[..5];
272 buf.put_u8(0); buf.put_u32(len as u32);
274 }
275 let buf = buf.split_to(len + 5).freeze();
276
277 self.body = Some(ProtoBody::from_bytes(buf));
278 }
279
280 async fn get_proto<M: Message + std::default::Default>(&mut self) -> Result<M, DecodeError> {
281 let hyper_request = self.hyper_request.take().unwrap().request;
282
283 let mut results = vec![];
284 let mut body = hyper_request.into_body();
285 while let Some(Ok(chunk)) = body.next().await {
286 results.put(chunk);
287 }
288
289 M::decode(&results[5..])
290 }
291}