thruster_grpc/
context.rs

1use async_trait::async_trait;
2use bytes::{BufMut, BytesMut};
3use hyper::body::Bytes;
4use hyper::header::{HeaderName, HeaderValue};
5use hyper::{HeaderMap, Response, StatusCode};
6use prost::{DecodeError, Message};
7use std::collections::HashMap;
8use std::convert::TryInto;
9use std::str;
10use thruster::context::hyper_request::HyperRequest;
11use thruster::middleware::query_params::HasQueryParams;
12use thruster::Context;
13use tokio_stream::StreamExt;
14
15use crate::body::ProtoBody;
16use crate::context::ProtoContext as Ctx;
17
/// Initial capacity of a context's response header map: enough room for
/// `content-type`, `server`, `grpc-status` and `trailers` without regrowing.
const DEFAULT_HEADER_CAPACITY: usize = 4;
20
21pub fn generate_context(request: HyperRequest, _state: &(), _path: &str) -> Ctx<()> {
22    Ctx::new(request, ())
23}
24
/// Value of the `SameSite` cookie attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SameSite {
    #[allow(dead_code)]
    Strict,
    #[allow(dead_code)]
    Lax,
}

/// Attributes attached to a cookie set through `ProtoContext::cookie`.
///
/// A zero `expires`/`max_age` and an empty `domain` mean "omit that
/// attribute" when the cookie is formatted.
#[derive(Debug, Clone)]
pub struct CookieOptions {
    pub domain: String,
    pub path: String,
    pub expires: u64,
    pub http_only: bool,
    pub max_age: u64,
    pub secure: bool,
    pub signed: bool,
    pub same_site: SameSite,
}

impl Default for CookieOptions {
    /// Path `/`, `SameSite=Strict`, every other attribute unset/false.
    fn default() -> CookieOptions {
        CookieOptions {
            domain: "".to_owned(),
            path: "/".to_owned(),
            expires: 0,
            http_only: false,
            max_age: 0,
            secure: false,
            signed: false,
            same_site: SameSite::Strict,
        }
    }
}
57
58#[derive(Default)]
59pub struct ProtoContext<T> {
60    pub body: Option<ProtoBody>,
61    pub query_params: Option<HashMap<String, String>>,
62    pub status: u16,
63    pub hyper_request: Option<HyperRequest>,
64    pub extra: T,
65    http_version: hyper::Version,
66    headers: HeaderMap,
67}
68
69impl<T> ProtoContext<T> {
70    pub fn new(req: HyperRequest, extra: T) -> ProtoContext<T> {
71        let mut ctx = ProtoContext {
72            body: None,
73            query_params: None,
74            headers: HeaderMap::with_capacity(DEFAULT_HEADER_CAPACITY),
75            status: 200,
76            hyper_request: Some(req),
77            extra,
78            http_version: hyper::Version::HTTP_11,
79        };
80
81        ctx.set("Server", "Thruster");
82
83        ctx
84    }
85
86    ///
87    /// Set the response status code
88    ///
89    #[allow(dead_code)]
90    pub fn status(&mut self, code: u32) {
91        self.status = code.try_into().unwrap();
92    }
93
94    ///
95    /// Set the response `Content-Type`. A shortcode for
96    ///
97    /// ```ignore
98    /// ctx.set("Content-Type", "some-val");
99    /// ```
100    ///
101    #[allow(dead_code)]
102    pub fn content_type(&mut self, c_type: &str) {
103        self.set("Content-Type", c_type);
104    }
105
106    ///
107    /// Set up a redirect, will default to 302, but can be changed after
108    /// the fact.
109    ///
110    /// ```ignore
111    /// ctx.set("Location", "/some-path");
112    /// ctx.status(302);
113    /// ```
114    ///
115    #[allow(dead_code)]
116    pub fn redirect(&mut self, destination: &str) {
117        self.status(302);
118
119        self.set("Location", destination);
120    }
121
122    ///
123    /// Sets a cookie on the response
124    ///
125    #[allow(dead_code)]
126    pub fn cookie(&mut self, name: &str, value: &str, options: &CookieOptions) {
127        let cookie_value = match self.headers.get("Set-Cookie") {
128            Some(val) => format!(
129                "{}, {}",
130                val.to_str().unwrap_or_else(|_| ""),
131                self.cookify_options(name, value, &options)
132            ),
133            None => self.cookify_options(name, value, &options),
134        };
135
136        self.set("Set-Cookie", &cookie_value);
137    }
138
139    #[allow(dead_code)]
140    fn cookify_options(&self, name: &str, value: &str, options: &CookieOptions) -> String {
141        let mut pieces = vec![format!("Path={}", options.path)];
142
143        if options.expires > 0 {
144            pieces.push(format!("Expires={}", options.expires));
145        }
146
147        if options.max_age > 0 {
148            pieces.push(format!("Max-Age={}", options.max_age));
149        }
150
151        if !options.domain.is_empty() {
152            pieces.push(format!("Domain={}", options.domain));
153        }
154
155        if options.secure {
156            pieces.push("Secure".to_owned());
157        }
158
159        if options.http_only {
160            pieces.push("HttpOnly".to_owned());
161        }
162
163        match options.same_site {
164            SameSite::Strict => pieces.push("SameSite=Strict".to_owned()),
165            SameSite::Lax => pieces.push("SameSite=Lax".to_owned()),
166        };
167
168        format!("{}={}; {}", name, value, pieces.join(", "))
169    }
170
171    #[allow(dead_code)]
172    pub fn set_http2(&mut self) {
173        self.http_version = hyper::Version::HTTP_2;
174    }
175
176    #[allow(dead_code)]
177    pub fn set_http11(&mut self) {
178        self.http_version = hyper::Version::HTTP_11;
179    }
180
181    #[allow(dead_code)]
182    pub fn set_http10(&mut self) {
183        self.http_version = hyper::Version::HTTP_10;
184    }
185
186    pub fn set_proto_status(&mut self, status: u16) {
187        self.headers
188            .insert("grpc-status", format!("{}", status).parse().unwrap());
189    }
190}
191
192impl<T> Context for ProtoContext<T> {
193    type Response = Response<ProtoBody>;
194
195    fn get_response(mut self) -> Self::Response {
196        let mut body = self.body.take().unwrap();
197        body.set_headers(self.headers.clone());
198        let mut response = Response::new(body);
199
200        *response.status_mut() = StatusCode::from_u16(self.status).unwrap();
201        *response.headers_mut() = self.headers;
202        *response.version_mut() = self.http_version;
203
204        response
205    }
206
207    fn set_body(&mut self, body: Vec<u8>) {
208        self.body.replace(ProtoBody::from_bytes(Bytes::from(body)));
209    }
210
211    fn set_body_bytes(&mut self, bytes: Bytes) {
212        self.body.replace(ProtoBody::from_bytes(bytes));
213    }
214
215    fn route(&self) -> &str {
216        let uri = self.hyper_request.as_ref().unwrap().request.uri();
217
218        match uri.path_and_query() {
219            Some(val) => val.as_str(),
220            None => uri.path(),
221        }
222    }
223
224    fn set(&mut self, key: &str, value: &str) {
225        self.headers.insert(
226            HeaderName::from_bytes(key.as_bytes()).unwrap(),
227            HeaderValue::from_str(value).unwrap(),
228        );
229    }
230
231    fn remove(&mut self, key: &str) {
232        self.headers.remove(key);
233    }
234}
235
236impl<T> HasQueryParams for ProtoContext<T> {
237    fn set_query_params(&mut self, query_params: HashMap<String, String>) {
238        self.query_params = Some(query_params);
239    }
240}
241
242impl<T> Clone for ProtoContext<T> {
243    fn clone(&self) -> Self {
244        panic!("Do not use, just for internals.");
245    }
246}
247
248#[async_trait]
249pub trait ProtoContextExt<T> {
250    async fn proto<M: Message + std::default::Default>(&mut self, message: M);
251    async fn get_proto<M: Message + std::default::Default>(&mut self) -> Result<M, DecodeError>;
252}
253
254#[async_trait]
255impl<T: Send> ProtoContextExt<T> for ProtoContext<T> {
256    async fn proto<M: Message + std::default::Default>(&mut self, message: M) {
257        self.set("content-type", "application/grpc");
258        self.set("grpc-status", "0");
259        self.set("trailers", "grpc-status");
260        self.set_http2();
261
262        let mut buf = BytesMut::new();
263        buf.reserve(5);
264        buf.put(&b"00000"[..]);
265
266        let _ = message.encode(&mut buf);
267
268        let len = buf.len() - 5;
269        assert!(len <= std::u32::MAX as usize);
270        {
271            let mut buf = &mut buf[..5];
272            buf.put_u8(0); // byte must be 0, reserve doesn't auto-zero
273            buf.put_u32(len as u32);
274        }
275        let buf = buf.split_to(len + 5).freeze();
276
277        self.body = Some(ProtoBody::from_bytes(buf));
278    }
279
280    async fn get_proto<M: Message + std::default::Default>(&mut self) -> Result<M, DecodeError> {
281        let hyper_request = self.hyper_request.take().unwrap().request;
282
283        let mut results = vec![];
284        let mut body = hyper_request.into_body();
285        while let Some(Ok(chunk)) = body.next().await {
286            results.put(chunk);
287        }
288
289        M::decode(&results[5..])
290    }
291}