1use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
4
5pub use cookie::{Cookie, Key, ParseError};
6
7use cookie::CookieJar as _CookieJar;
8
9use crate::{
10 WebContext,
11 body::ResponseBody,
12 error::{Error, ErrorStatus, ExtensionNotFound, HeaderNotFound, error_from_service, forward_blank_bad_request},
13 handler::{FromRequest, Responder},
14 http::{
15 WebResponse,
16 header::ToStrError,
17 header::{COOKIE, HeaderValue, SET_COOKIE},
18 },
19};
20
// Shared boilerplate for newtype wrappers around `Key`: lossless conversions in both
// directions plus a `generate` constructor.
macro_rules! key_impl {
    ($ty: tt) => {
        impl From<Key> for $ty {
            fn from(inner: Key) -> Self {
                Self(inner)
            }
        }

        impl From<$ty> for Key {
            fn from(wrapper: $ty) -> Self {
                let $ty(inner) = wrapper;
                inner
            }
        }

        impl $ty {
            /// Create a wrapper around a freshly generated [`Key`] (see `Key::generate`).
            #[inline]
            pub fn generate() -> Self {
                Key::generate().into()
            }
        }
    };
}
45
46#[derive(Clone, Debug)]
51pub struct StateKey(Key);
52
53key_impl!(StateKey);
54
55impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for StateKey
56where
57 C: Borrow<Self>,
58{
59 type Type<'b> = Self;
60 type Error = Error;
61
62 #[inline]
63 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
64 Ok(ctx.state().borrow().clone())
65 }
66}
67
68#[derive(Clone, Debug)]
73pub struct ExtensionKey(Key);
74
75key_impl!(ExtensionKey);
76
77impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for ExtensionKey {
78 type Type<'b> = Self;
79 type Error = Error;
80
81 #[inline]
82 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
83 ctx.req()
84 .extensions()
85 .get::<Self>()
86 .cloned()
87 .ok_or_else(|| Error::from(ExtensionNotFound::from_type::<Self>()))
88 }
89}
90
91pub struct CookieJar<K = Plain> {
93 jar: _CookieJar,
94 key: K,
95}
96
97impl CookieJar {
98 pub fn plain() -> Self {
100 Self {
101 jar: _CookieJar::new(),
102 key: Plain,
103 }
104 }
105
106 #[inline]
108 pub fn get(&self, name: &str) -> Option<&Cookie> {
109 self.jar.get(name)
110 }
111
112 #[inline]
114 pub fn add<C>(&mut self, cookie: C)
115 where
116 C: Into<Cookie<'static>>,
117 {
118 self.jar.add(cookie)
119 }
120
121 #[inline]
123 pub fn remove<C>(&mut self, cookie: C)
124 where
125 C: Into<Cookie<'static>>,
126 {
127 self.jar.remove(cookie)
128 }
129}
130
131#[doc(hidden)]
132pub struct Plain;
133
134impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for Plain {
135 type Type<'b> = Self;
136 type Error = Error;
137
138 #[inline]
139 async fn from_request(_: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
140 Ok(Self)
141 }
142}
143
144macro_rules! cookie_variant {
145 ($variant: tt, $method: tt, $method_mut: tt) => {
146 pub struct $variant<K> {
150 key: Key,
151 _key: PhantomData<fn(K)>,
152 }
153
154 impl<K> Deref for $variant<K> {
155 type Target = Key;
156
157 fn deref(&self) -> &Self::Target {
158 &self.key
159 }
160 }
161
162 impl CookieJar {
163 pub fn $method<K>(key: K) -> CookieJar<$variant<K>>
164 where
165 K: Into<Key>,
166 {
167 CookieJar {
168 jar: _CookieJar::new(),
169 key: $variant {
170 key: key.into(),
171 _key: PhantomData,
172 },
173 }
174 }
175 }
176
177 impl<K> CookieJar<$variant<K>> {
178 #[inline]
179 pub fn get(&self, name: &str) -> Option<Cookie> {
180 self.jar.$method(&self.key).get(name)
181 }
182
183 #[inline]
184 pub fn add<C>(&mut self, cookie: C)
185 where
186 C: Into<Cookie<'static>>,
187 {
188 self.jar.$method_mut(&self.key).add(cookie)
189 }
190
191 #[inline]
192 pub fn remove<C>(&mut self, cookie: C)
193 where
194 C: Into<Cookie<'static>>,
195 {
196 self.jar.$method_mut(&self.key).remove(cookie)
197 }
198 }
199
200 impl<'a, 'r, C, B, K> FromRequest<'a, WebContext<'r, C, B>> for $variant<K>
201 where
202 K: for<'a2, 'r2> FromRequest<'a2, WebContext<'r2, C, B>, Error = Error> + Into<Key>,
203 {
204 type Type<'b> = Self;
205 type Error = Error;
206
207 #[inline]
208 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
209 K::from_request(ctx).await.map(|key| $variant {
210 key: key.into(),
211 _key: PhantomData,
212 })
213 }
214 }
215 };
216}
217
218cookie_variant!(Private, private, private_mut);
219cookie_variant!(Signed, signed, signed_mut);
220
221impl<'a, 'r, C, B, K> FromRequest<'a, WebContext<'r, C, B>> for CookieJar<K>
222where
223 K: for<'a2, 'r2> FromRequest<'a2, WebContext<'r2, C, B>, Error = Error>,
224{
225 type Type<'b> = CookieJar<K>;
226 type Error = Error;
227
228 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
229 let key = K::from_request(ctx).await?;
230
231 let mut jar = _CookieJar::new();
232
233 let headers = ctx.req().headers();
234
235 if !headers.contains_key(COOKIE) {
236 return Err(Error::from(HeaderNotFound(COOKIE)));
237 }
238
239 for val in headers.get_all(COOKIE) {
240 for val in val.to_str()?.split(';') {
241 let cookie = Cookie::parse_encoded(val.to_owned())?;
242 jar.add_original(cookie);
243 }
244 }
245
246 Ok(CookieJar { jar, key })
247 }
248}
249
250error_from_service!(ToStrError);
251forward_blank_bad_request!(ToStrError);
252
253error_from_service!(ParseError);
254forward_blank_bad_request!(ParseError);
255
256impl<'r, C, B, K> Responder<WebContext<'r, C, B>> for CookieJar<K> {
257 type Response = WebResponse;
258 type Error = Error;
259
260 async fn respond(self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
261 let res = ctx.into_response(ResponseBody::empty());
262 Responder::<WebContext<'r, C, B>>::map(self, res)
263 }
264
265 fn map(self, mut res: Self::Response) -> Result<Self::Response, Self::Error> {
266 let headers = res.headers_mut();
267 for cookie in self.jar.delta() {
268 let value = HeaderValue::try_from(cookie.encoded().to_string()).map_err(|_| ErrorStatus::internal())?;
269 headers.append(SET_COOKIE, value);
270 }
271 Ok(res)
272 }
273}
274
#[cfg(test)]
mod test {
    // NOTE(review): `now_or_panic` presumably polls a future once and panics if it is not ready.
    use xitca_unsafe_collection::futures::NowOrPanic;

    use super::*;

    /// Round-trips a plain cookie (response `Set-Cookie` -> request `Cookie` -> extracted jar),
    /// then checks the `Set-Cookie` header produced after adding a new cookie.
    #[test]
    fn cookie() {
        let mut ctx = WebContext::new_test(&());
        let mut ctx = ctx.as_web_ctx();

        let mut jar = CookieJar::plain();
        jar.add(("foo", "bar"));

        // Take the `Set-Cookie` value produced by responding with the jar...
        let cookie = jar
            .respond(ctx.reborrow())
            .now_or_panic()
            .unwrap()
            .headers_mut()
            .remove(SET_COOKIE)
            .unwrap();

        // ...and feed it back as the request's `Cookie` header.
        ctx.req_mut().headers_mut().insert(COOKIE, cookie);

        let mut jar: CookieJar = CookieJar::from_request(&ctx).now_or_panic().unwrap();

        let val = jar.get("foo").unwrap();
        assert_eq!(val.name(), "foo");
        assert_eq!(val.value(), "bar");

        jar.add(("996", "251"));

        let res = CookieJar::respond(jar, ctx).now_or_panic().unwrap();

        // `foo` was extracted as an original cookie, so the emitted header is the new one.
        let header = res.headers().get(SET_COOKIE).unwrap();
        assert_eq!(header.to_str().unwrap(), "996=251");
    }

    /// Test key source: clones a `MyKey` out of the request extensions (panics if absent).
    #[derive(Clone)]
    struct MyKey(Key);

    impl From<MyKey> for Key {
        fn from(value: MyKey) -> Self {
            value.0
        }
    }

    impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for MyKey {
        type Type<'b> = MyKey;
        type Error = Error;

        async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
            Ok(ctx.req().extensions().get::<MyKey>().unwrap().clone())
        }
    }

    /// Round-trips a cookie through a `CookieJar::private` jar and reads it back with the same
    /// key, supplied to the extractor via `MyKey` in the request extensions.
    #[test]
    fn private_cookie() {
        let mut ctx = WebContext::new_test(&());
        let mut ctx = ctx.as_web_ctx();

        let key = Key::generate();

        let mut jar = CookieJar::private(key.clone());
        jar.add(("foo", "bar"));

        // Capture the protected `Set-Cookie` value written by the private jar.
        let cookie = jar
            .respond(ctx.reborrow())
            .now_or_panic()
            .unwrap()
            .headers_mut()
            .remove(SET_COOKIE)
            .unwrap();

        // Send it back as a request cookie and expose the same key to the extractor.
        ctx.req_mut().headers_mut().insert(COOKIE, cookie);
        ctx.req_mut().extensions_mut().insert(MyKey(key));

        let jar = CookieJar::<Private<MyKey>>::from_request(&ctx).now_or_panic().unwrap();

        // Reading through the private view must recover the original name and value.
        let val = jar.get("foo").unwrap();
        assert_eq!(val.name(), "foo");
        assert_eq!(val.value(), "bar");
    }

    /// Same round-trip as `private_cookie`, but through a `CookieJar::signed` jar.
    #[test]
    fn signed_cookie() {
        let mut ctx = WebContext::new_test(&());
        let mut ctx = ctx.as_web_ctx();

        let key = Key::generate();

        let mut jar = CookieJar::signed(key.clone());
        jar.add(("foo", "bar"));

        // Capture the signed `Set-Cookie` value written by the signed jar.
        let cookie = jar
            .respond(ctx.reborrow())
            .now_or_panic()
            .unwrap()
            .headers_mut()
            .remove(SET_COOKIE)
            .unwrap();

        // Send it back as a request cookie and expose the same key to the extractor.
        ctx.req_mut().headers_mut().insert(COOKIE, cookie);
        ctx.req_mut().extensions_mut().insert(MyKey(key));

        let jar = CookieJar::<Signed<MyKey>>::from_request(&ctx).now_or_panic().unwrap();

        // Reading through the signed view must verify and recover the original cookie.
        let val = jar.get("foo").unwrap();
        assert_eq!(val.name(), "foo");
        assert_eq!(val.value(), "bar");
    }
}