tower_cookies/
lib.rs

1//! A cookie manager middleware built on top of [tower].
2//!
3//! ## Example
4//!
5//! With [axum]:
6//!
7//! ```rust,no_run
8//! use axum::{routing::get, Router};
9//! use std::net::SocketAddr;
10//! use tower_cookies::{Cookie, CookieManagerLayer, Cookies};
11//!
12//! # #[cfg(feature = "axum-core")]
13//! #[tokio::main]
14//! async fn main() {
15//!     let app = Router::new()
16//!         .route("/", get(handler))
17//!         .layer(CookieManagerLayer::new());
18//!
19//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
20//!     let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
21//!     axum::serve(listener, app.into_make_service())
22//!         .await
23//!         .unwrap();
24//! }
25//! # #[cfg(not(feature = "axum-core"))]
26//! # fn main() {}
27//!
28//! async fn handler(cookies: Cookies) -> &'static str {
29//!     cookies.add(Cookie::new("hello_world", "hello_world"));
30//!
31//!     "Check your cookies."
32//! }
33//! ```
34//!
//! A complete CRUD cookie example can be found in [examples/counter.rs][example].
36//!
37//! [axum]: https://crates.io/crates/axum
38//! [tower]: https://crates.io/crates/tower
39//! [example]: https://github.com/imbolc/tower-cookies/blob/main/examples/counter.rs
40
41#![warn(clippy::all, missing_docs, nonstandard_style, future_incompatible)]
42#![forbid(unsafe_code)]
43#![cfg_attr(docsrs, feature(doc_cfg))]
44
45use cookie::CookieJar;
46use http::HeaderValue;
47use parking_lot::Mutex;
48use std::sync::Arc;
49
50#[doc(inline)]
51pub use self::service::{CookieManager, CookieManagerLayer};
52
53#[cfg(feature = "signed")]
54pub use self::signed::SignedCookies;
55
56#[cfg(feature = "private")]
57pub use self::private::PrivateCookies;
58
59#[cfg(any(feature = "signed", feature = "private"))]
60pub use cookie::Key;
61
62pub use cookie::Cookie;
63
64#[doc(inline)]
65pub use cookie;
66
67#[cfg(feature = "axum-core")]
68#[cfg_attr(docsrs, doc(cfg(feature = "axum-core")))]
69mod extract;
70
71#[cfg(feature = "signed")]
72mod signed;
73
74#[cfg(feature = "private")]
75mod private;
76
77pub mod service;
78
79/// A parsed on-demand cookie jar.
80#[derive(Clone, Debug, Default)]
81pub struct Cookies {
82    inner: Arc<Mutex<Inner>>,
83}
84
85impl Cookies {
86    fn new(headers: Vec<HeaderValue>) -> Self {
87        let inner = Inner {
88            headers,
89            ..Default::default()
90        };
91        Self {
92            inner: Arc::new(Mutex::new(inner)),
93        }
94    }
95
96    /// Adds [`Cookie`] to this jar. If a [`Cookie`] with the same name already exists, it is
97    /// replaced with provided cookie.
98    pub fn add(&self, cookie: Cookie<'static>) {
99        let mut inner = self.inner.lock();
100        inner.changed = true;
101        inner.jar().add(cookie);
102    }
103
104    /// Returns the [`Cookie`] with the given name. Returns [`None`] if it doesn't exist.
105    pub fn get(&self, name: &str) -> Option<Cookie> {
106        let mut inner = self.inner.lock();
107        inner.jar().get(name).cloned()
108    }
109
110    /// Removes [`Cookie`] from this jar.
111    pub fn remove(&self, cookie: Cookie<'static>) {
112        let mut inner = self.inner.lock();
113        inner.changed = true;
114        inner.jar().remove(cookie);
115    }
116
117    /// Returns all the [`Cookie`]s present in this jar.
118    ///
119    /// This method collects [`Cookie`]s into a vector instead of iterating through them to
120    /// minimize the mutex locking time.
121    pub fn list(&self) -> Vec<Cookie> {
122        let mut inner = self.inner.lock();
123        inner.jar().iter().cloned().collect()
124    }
125
126    /// Returns a child [`SignedCookies`] jar for interations with signed by the `key` cookies.
127    ///
128    /// # Example:
129    /// ```
130    /// use cookie::{Cookie, Key};
131    /// use tower_cookies::Cookies;
132    ///
133    /// let cookies = Cookies::default();
134    /// let key = Key::generate();
135    /// let signed = cookies.signed(&key);
136    ///
137    /// let foo = Cookie::new("foo", "bar");
138    /// signed.add(foo.clone());
139    ///
140    /// assert_eq!(signed.get("foo"), Some(foo.clone()));
141    /// assert_ne!(cookies.get("foo"), Some(foo));
142    /// ```
143    #[cfg(feature = "signed")]
144    pub fn signed<'a>(&self, key: &'a cookie::Key) -> SignedCookies<'a> {
145        SignedCookies::new(self, key)
146    }
147
148    /// Returns a child [`PrivateCookies`] jar for encrypting and decrypting cookies.
149    ///
150    /// # Example:
151    /// ```
152    /// use cookie::{Cookie, Key};
153    /// use tower_cookies::Cookies;
154    ///
155    /// let cookies = Cookies::default();
156    /// let key = Key::generate();
157    /// let private = cookies.private(&key);
158    ///
159    /// let foo = Cookie::new("foo", "bar");
160    /// private.add(foo.clone());
161    ///
162    /// assert_eq!(private.get("foo"), Some(foo.clone()));
163    /// assert_ne!(cookies.get("foo"), Some(foo));
164    /// ```
165    #[cfg(feature = "private")]
166    pub fn private<'a>(&self, key: &'a cookie::Key) -> PrivateCookies<'a> {
167        PrivateCookies::new(self, key)
168    }
169}
170
171#[derive(Debug, Default)]
172struct Inner {
173    headers: Vec<HeaderValue>,
174    jar: Option<CookieJar>,
175    changed: bool,
176}
177
178impl Inner {
179    fn jar(&mut self) -> &mut CookieJar {
180        if self.jar.is_none() {
181            let mut jar = CookieJar::new();
182            for header in &self.headers {
183                if let Ok(header_str) = std::str::from_utf8(header.as_bytes()) {
184                    for cookie_str in header_str.split(';') {
185                        if let Ok(cookie) = cookie::Cookie::parse_encoded(cookie_str.to_owned()) {
186                            jar.add_original(cookie);
187                        }
188                    }
189                }
190            }
191            self.jar = Some(jar);
192        }
193        self.jar.as_mut().unwrap()
194    }
195}
196
#[cfg(all(test, feature = "axum-core"))]
mod tests {
    use crate::{CookieManagerLayer, Cookies};
    use axum::{body::Body, routing::get, Router};
    use cookie::Cookie;
    use http::{header, Request};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Router exposing `/list`, `/add` and `/remove` handlers behind [`CookieManagerLayer`].
    fn app() -> Router {
        Router::new()
            .route(
                "/list",
                get(|cookies: Cookies| async move {
                    let mut pairs: Vec<String> = cookies
                        .list()
                        .iter()
                        .map(|c| format!("{}={}", c.name(), c.value()))
                        .collect();
                    pairs.sort();
                    pairs.join(", ")
                }),
            )
            .route(
                "/add",
                get(|cookies: Cookies| async move {
                    for (name, value) in [("baz", "3"), ("spam", "4")] {
                        cookies.add(Cookie::new(name, value));
                    }
                }),
            )
            .route(
                "/remove",
                get(|cookies: Cookies| async move {
                    cookies.remove(Cookie::new("foo", ""));
                }),
            )
            .layer(CookieManagerLayer::new())
    }

    /// Builds a GET request for `uri`, appending one `Cookie` header per entry.
    fn request(uri: &str, cookie_headers: &[&str]) -> Request<Body> {
        cookie_headers
            .iter()
            .fold(Request::builder().uri(uri), |builder, value| {
                builder.header(header::COOKIE, *value)
            })
            .body(Body::empty())
            .unwrap()
    }

    /// Collects a response body into a string, replacing invalid UTF-8 lossily.
    async fn body_string(body: Body) -> String {
        let collected = body.collect().await.unwrap();
        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    #[tokio::test]
    async fn read_cookies() {
        let res = app()
            .oneshot(request("/list", &["foo=1; bar=2"]))
            .await
            .unwrap();
        assert_eq!(body_string(res.into_body()).await, "bar=2, foo=1");
    }

    #[tokio::test]
    async fn read_multi_header_cookies() {
        let res = app()
            .oneshot(request("/list", &["foo=1", "bar=2"]))
            .await
            .unwrap();
        assert_eq!(body_string(res.into_body()).await, "bar=2, foo=1");
    }

    #[tokio::test]
    async fn add_cookies() {
        let res = app()
            .oneshot(request("/add", &["foo=1; bar=2"]))
            .await
            .unwrap();
        let mut set_cookies: Vec<_> = res.headers().get_all(header::SET_COOKIE).iter().collect();
        set_cookies.sort();
        assert_eq!(set_cookies, ["baz=3", "spam=4"]);
    }

    #[tokio::test]
    async fn remove_cookies() {
        let res = app()
            .oneshot(request("/remove", &["foo=1; bar=2"]))
            .await
            .unwrap();
        let mut set_cookies = res.headers().get_all(header::SET_COOKIE).iter();
        let first = set_cookies.next().unwrap().to_str().unwrap();
        assert!(first.starts_with("foo=; Max-Age=0"));
        assert_eq!(set_cookies.next(), None);
    }
}