xitca_web/handler/types/
cookie.rs

1//! type extractor and responder for cookies.
2
3use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
4
5pub use cookie::{Cookie, Key, ParseError};
6
7use cookie::CookieJar as _CookieJar;
8
9use crate::{
10    WebContext,
11    body::ResponseBody,
12    error::{Error, ErrorStatus, ExtensionNotFound, HeaderNotFound, error_from_service, forward_blank_bad_request},
13    handler::{FromRequest, Responder},
14    http::{
15        WebResponse,
16        header::ToStrError,
17        header::{COOKIE, HeaderValue, SET_COOKIE},
18    },
19};
20
// Boilerplate shared by the new-type wrappers around `Key` declared in this module.
//
// For a tuple struct `$name(Key)` this generates:
// - lossless conversions in both directions (`From<Key>` and `From<$name> for Key`)
// - a `generate` constructor delegating to `Key::generate`
macro_rules! key_impl {
    ($name: tt) => {
        impl From<Key> for $name {
            fn from(inner: Key) -> Self {
                $name(inner)
            }
        }

        impl From<$name> for Key {
            fn from($name(inner): $name) -> Self {
                inner
            }
        }

        impl $name {
            /// construct a fresh signing/encryption key from a secure, random source.
            /// see [Key] for further detail.
            #[inline]
            pub fn generate() -> Self {
                $name(Key::generate())
            }
        }
    };
}
45
46/// an extractor type wrapping around [Key] hinting itself can be extracted from
47/// application state. See [App::with_state] for compile time state management.
48///
49/// [App::with_state]: crate::App::with_state
50#[derive(Clone, Debug)]
51pub struct StateKey(Key);
52
53key_impl!(StateKey);
54
55impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for StateKey
56where
57    C: Borrow<Self>,
58{
59    type Type<'b> = Self;
60    type Error = Error;
61
62    #[inline]
63    async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
64        Ok(ctx.state().borrow().clone())
65    }
66}
67
68/// an extractor type wrapping around [Key] hinting itself can be extracted from
69/// request extensions. See [WebRequest::extensions] for run time state management.
70///
71/// [WebRequest::extensions]: crate::http::WebRequest::extensions
72#[derive(Clone, Debug)]
73pub struct ExtensionKey(Key);
74
75key_impl!(ExtensionKey);
76
77impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for ExtensionKey {
78    type Type<'b> = Self;
79    type Error = Error;
80
81    #[inline]
82    async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
83        ctx.req()
84            .extensions()
85            .get::<Self>()
86            .cloned()
87            .ok_or_else(|| Error::from(ExtensionNotFound::from_type::<Self>()))
88    }
89}
90
91/// container of cookies extracted from request.
92pub struct CookieJar<K = Plain> {
93    jar: _CookieJar,
94    key: K,
95}
96
97impl CookieJar {
98    /// construct a new cookie container with no encryption.
99    pub fn plain() -> Self {
100        Self {
101            jar: _CookieJar::new(),
102            key: Plain,
103        }
104    }
105
106    /// get cookie with given key name
107    #[inline]
108    pub fn get(&self, name: &str) -> Option<&Cookie> {
109        self.jar.get(name)
110    }
111
112    /// add cookie to container.
113    #[inline]
114    pub fn add<C>(&mut self, cookie: C)
115    where
116        C: Into<Cookie<'static>>,
117    {
118        self.jar.add(cookie)
119    }
120
121    /// remove cookie from container.
122    #[inline]
123    pub fn remove<C>(&mut self, cookie: C)
124    where
125        C: Into<Cookie<'static>>,
126    {
127        self.jar.remove(cookie)
128    }
129}
130
131#[doc(hidden)]
132pub struct Plain;
133
134impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for Plain {
135    type Type<'b> = Self;
136    type Error = Error;
137
138    #[inline]
139    async fn from_request(_: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
140        Ok(Self)
141    }
142}
143
144macro_rules! cookie_variant {
145    ($variant: tt, $method: tt, $method_mut: tt) => {
146        /// encrypted cookie container type.
147        /// must annotate the generic type param with types that can provide the key
148        /// for encryption. See [StateKey] and [ExtensionKey] for detail.
149        pub struct $variant<K> {
150            key: Key,
151            _key: PhantomData<fn(K)>,
152        }
153
154        impl<K> Deref for $variant<K> {
155            type Target = Key;
156
157            fn deref(&self) -> &Self::Target {
158                &self.key
159            }
160        }
161
162        impl CookieJar {
163            pub fn $method<K>(key: K) -> CookieJar<$variant<K>>
164            where
165                K: Into<Key>,
166            {
167                CookieJar {
168                    jar: _CookieJar::new(),
169                    key: $variant {
170                        key: key.into(),
171                        _key: PhantomData,
172                    },
173                }
174            }
175        }
176
177        impl<K> CookieJar<$variant<K>> {
178            #[inline]
179            pub fn get(&self, name: &str) -> Option<Cookie> {
180                self.jar.$method(&self.key).get(name)
181            }
182
183            #[inline]
184            pub fn add<C>(&mut self, cookie: C)
185            where
186                C: Into<Cookie<'static>>,
187            {
188                self.jar.$method_mut(&self.key).add(cookie)
189            }
190
191            #[inline]
192            pub fn remove<C>(&mut self, cookie: C)
193            where
194                C: Into<Cookie<'static>>,
195            {
196                self.jar.$method_mut(&self.key).remove(cookie)
197            }
198        }
199
200        impl<'a, 'r, C, B, K> FromRequest<'a, WebContext<'r, C, B>> for $variant<K>
201        where
202            K: for<'a2, 'r2> FromRequest<'a2, WebContext<'r2, C, B>, Error = Error> + Into<Key>,
203        {
204            type Type<'b> = Self;
205            type Error = Error;
206
207            #[inline]
208            async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
209                K::from_request(ctx).await.map(|key| $variant {
210                    key: key.into(),
211                    _key: PhantomData,
212                })
213            }
214        }
215    };
216}
217
218cookie_variant!(Private, private, private_mut);
219cookie_variant!(Signed, signed, signed_mut);
220
221impl<'a, 'r, C, B, K> FromRequest<'a, WebContext<'r, C, B>> for CookieJar<K>
222where
223    K: for<'a2, 'r2> FromRequest<'a2, WebContext<'r2, C, B>, Error = Error>,
224{
225    type Type<'b> = CookieJar<K>;
226    type Error = Error;
227
228    async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
229        let key = K::from_request(ctx).await?;
230
231        let mut jar = _CookieJar::new();
232
233        let headers = ctx.req().headers();
234
235        if !headers.contains_key(COOKIE) {
236            return Err(Error::from(HeaderNotFound(COOKIE)));
237        }
238
239        for val in headers.get_all(COOKIE) {
240            for val in val.to_str()?.split(';') {
241                let cookie = Cookie::parse_encoded(val.to_owned())?;
242                jar.add_original(cookie);
243            }
244        }
245
246        Ok(CookieJar { jar, key })
247    }
248}
249
250error_from_service!(ToStrError);
251forward_blank_bad_request!(ToStrError);
252
253error_from_service!(ParseError);
254forward_blank_bad_request!(ParseError);
255
256impl<'r, C, B, K> Responder<WebContext<'r, C, B>> for CookieJar<K> {
257    type Response = WebResponse;
258    type Error = Error;
259
260    async fn respond(self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
261        let res = ctx.into_response(ResponseBody::empty());
262        Responder::<WebContext<'r, C, B>>::map(self, res)
263    }
264
265    fn map(self, mut res: Self::Response) -> Result<Self::Response, Self::Error> {
266        let headers = res.headers_mut();
267        for cookie in self.jar.delta() {
268            let value = HeaderValue::try_from(cookie.encoded().to_string()).map_err(|_| ErrorStatus::internal())?;
269            headers.append(SET_COOKIE, value);
270        }
271        Ok(res)
272    }
273}
274
#[cfg(test)]
mod test {
    use xitca_unsafe_collection::futures::NowOrPanic;

    use super::*;

    // round trip of a plain container: respond -> feed header back as request -> extract.
    #[test]
    fn cookie() {
        let mut ctx = WebContext::new_test(&());
        let mut ctx = ctx.as_web_ctx();

        let mut outgoing = CookieJar::plain();
        outgoing.add(("foo", "bar"));

        let mut res = outgoing.respond(ctx.reborrow()).now_or_panic().unwrap();
        let set_cookie = res.headers_mut().remove(SET_COOKIE).unwrap();

        // the value sent to the client is what the client would send back.
        ctx.req_mut().headers_mut().insert(COOKIE, set_cookie);

        let mut incoming: CookieJar = CookieJar::from_request(&ctx).now_or_panic().unwrap();

        let foo = incoming.get("foo").unwrap();
        assert_eq!(foo.name(), "foo");
        assert_eq!(foo.value(), "bar");

        incoming.add(("996", "251"));

        let res = CookieJar::respond(incoming, ctx).now_or_panic().unwrap();

        // only the newly added cookie is part of the response.
        let set_cookie = res.headers().get(SET_COOKIE).unwrap();
        assert_eq!(set_cookie.to_str().unwrap(), "996=251");
    }

    // user defined key provider reading itself from request extensions.
    #[derive(Clone)]
    struct MyKey(Key);

    impl From<MyKey> for Key {
        fn from(MyKey(inner): MyKey) -> Self {
            inner
        }
    }

    impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for MyKey {
        type Type<'b> = MyKey;
        type Error = Error;

        async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
            let found = ctx.req().extensions().get::<Self>().cloned();
            Ok(found.unwrap())
        }
    }

    #[test]
    fn private_cookie() {
        let mut ctx = WebContext::new_test(&());
        let mut ctx = ctx.as_web_ctx();

        let key = Key::generate();

        let mut outgoing = CookieJar::private(key.clone());
        outgoing.add(("foo", "bar"));

        let mut res = outgoing.respond(ctx.reborrow()).now_or_panic().unwrap();
        let set_cookie = res.headers_mut().remove(SET_COOKIE).unwrap();

        ctx.req_mut().headers_mut().insert(COOKIE, set_cookie);
        ctx.req_mut().extensions_mut().insert(MyKey(key));

        let incoming = CookieJar::<Private<MyKey>>::from_request(&ctx).now_or_panic().unwrap();

        let foo = incoming.get("foo").unwrap();
        assert_eq!(foo.name(), "foo");
        assert_eq!(foo.value(), "bar");
    }

    #[test]
    fn signed_cookie() {
        let mut ctx = WebContext::new_test(&());
        let mut ctx = ctx.as_web_ctx();

        let key = Key::generate();

        let mut outgoing = CookieJar::signed(key.clone());
        outgoing.add(("foo", "bar"));

        let mut res = outgoing.respond(ctx.reborrow()).now_or_panic().unwrap();
        let set_cookie = res.headers_mut().remove(SET_COOKIE).unwrap();

        ctx.req_mut().headers_mut().insert(COOKIE, set_cookie);
        ctx.req_mut().extensions_mut().insert(MyKey(key));

        let incoming = CookieJar::<Signed<MyKey>>::from_request(&ctx).now_or_panic().unwrap();

        let foo = incoming.get("foo").unwrap();
        assert_eq!(foo.name(), "foo");
        assert_eq!(foo.value(), "bar");
    }
}