1#![warn(clippy::all, missing_docs, nonstandard_style, future_incompatible)]
42#![forbid(unsafe_code)]
43#![cfg_attr(docsrs, feature(doc_cfg))]
44
45use cookie::CookieJar;
46use http::HeaderValue;
47use parking_lot::Mutex;
48use std::sync::Arc;
49
50#[doc(inline)]
51pub use self::service::{CookieManager, CookieManagerLayer};
52
53#[cfg(feature = "signed")]
54pub use self::signed::SignedCookies;
55
56#[cfg(feature = "private")]
57pub use self::private::PrivateCookies;
58
59#[cfg(any(feature = "signed", feature = "private"))]
60pub use cookie::Key;
61
62pub use cookie::Cookie;
63
64#[doc(inline)]
65pub use cookie;
66
67#[cfg(feature = "axum-core")]
68#[cfg_attr(docsrs, doc(cfg(feature = "axum-core")))]
69mod extract;
70
71#[cfg(feature = "signed")]
72mod signed;
73
74#[cfg(feature = "private")]
75mod private;
76
77pub mod service;
78
79#[derive(Clone, Debug, Default)]
81pub struct Cookies {
82 inner: Arc<Mutex<Inner>>,
83}
84
85impl Cookies {
86 fn new(headers: Vec<HeaderValue>) -> Self {
87 let inner = Inner {
88 headers,
89 ..Default::default()
90 };
91 Self {
92 inner: Arc::new(Mutex::new(inner)),
93 }
94 }
95
96 pub fn add(&self, cookie: Cookie<'static>) {
99 let mut inner = self.inner.lock();
100 inner.changed = true;
101 inner.jar().add(cookie);
102 }
103
104 pub fn get(&self, name: &str) -> Option<Cookie> {
106 let mut inner = self.inner.lock();
107 inner.jar().get(name).cloned()
108 }
109
110 pub fn remove(&self, cookie: Cookie<'static>) {
112 let mut inner = self.inner.lock();
113 inner.changed = true;
114 inner.jar().remove(cookie);
115 }
116
117 pub fn list(&self) -> Vec<Cookie> {
122 let mut inner = self.inner.lock();
123 inner.jar().iter().cloned().collect()
124 }
125
126 #[cfg(feature = "signed")]
144 pub fn signed<'a>(&self, key: &'a cookie::Key) -> SignedCookies<'a> {
145 SignedCookies::new(self, key)
146 }
147
148 #[cfg(feature = "private")]
166 pub fn private<'a>(&self, key: &'a cookie::Key) -> PrivateCookies<'a> {
167 PrivateCookies::new(self, key)
168 }
169}
170
/// Shared state behind a [`Cookies`] handle.
#[derive(Debug, Default)]
struct Inner {
    /// Raw header values the jar is parsed from on first access.
    headers: Vec<HeaderValue>,
    /// Lazily parsed jar; `None` until `Inner::jar` is first called.
    jar: Option<CookieJar>,
    /// Set to `true` by `Cookies::add` and `Cookies::remove`.
    /// NOTE(review): presumably read by the service to decide whether to emit
    /// `Set-Cookie` headers — confirm.
    changed: bool,
}
177
178impl Inner {
179 fn jar(&mut self) -> &mut CookieJar {
180 if self.jar.is_none() {
181 let mut jar = CookieJar::new();
182 for header in &self.headers {
183 if let Ok(header_str) = std::str::from_utf8(header.as_bytes()) {
184 for cookie_str in header_str.split(';') {
185 if let Ok(cookie) = cookie::Cookie::parse_encoded(cookie_str.to_owned()) {
186 jar.add_original(cookie);
187 }
188 }
189 }
190 }
191 self.jar = Some(jar);
192 }
193 self.jar.as_mut().unwrap()
194 }
195}
196
#[cfg(all(test, feature = "axum-core"))]
mod tests {
    use crate::{CookieManagerLayer, Cookies};
    use axum::{body::Body, routing::get, Router};
    use cookie::Cookie;
    use http::{header, Request, Response};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Router exposing `/list`, `/add` and `/remove`, wrapped in the cookie
    /// middleware.
    fn app() -> Router {
        Router::new()
            .route(
                "/list",
                get(|cookies: Cookies| async move {
                    let mut pairs: Vec<String> = cookies
                        .list()
                        .into_iter()
                        .map(|c| format!("{}={}", c.name(), c.value()))
                        .collect();
                    pairs.sort();
                    pairs.join(", ")
                }),
            )
            .route(
                "/add",
                get(|cookies: Cookies| async move {
                    for (name, value) in [("baz", "3"), ("spam", "4")] {
                        cookies.add(Cookie::new(name, value));
                    }
                }),
            )
            .route(
                "/remove",
                get(|cookies: Cookies| async move {
                    cookies.remove(Cookie::new("foo", ""));
                }),
            )
            .layer(CookieManagerLayer::new())
    }

    /// Sends a GET to `uri` with one `Cookie` header per entry in
    /// `cookie_headers`, and returns the app's response.
    async fn send(uri: &str, cookie_headers: &[&str]) -> Response<Body> {
        let mut builder = Request::builder().uri(uri);
        for value in cookie_headers {
            builder = builder.header(header::COOKIE, *value);
        }
        let request = builder.body(Body::empty()).unwrap();
        app().oneshot(request).await.unwrap()
    }

    /// Collects a response body into a (lossily decoded) string.
    async fn body_string(body: Body) -> String {
        let collected = body.collect().await.unwrap();
        String::from_utf8_lossy(&collected.to_bytes()).into_owned()
    }

    #[tokio::test]
    async fn read_cookies() {
        let res = send("/list", &["foo=1; bar=2"]).await;
        assert_eq!(body_string(res.into_body()).await, "bar=2, foo=1");
    }

    #[tokio::test]
    async fn read_multi_header_cookies() {
        let res = send("/list", &["foo=1", "bar=2"]).await;
        assert_eq!(body_string(res.into_body()).await, "bar=2, foo=1");
    }

    #[tokio::test]
    async fn add_cookies() {
        let res = send("/add", &["foo=1; bar=2"]).await;
        let mut set_cookies: Vec<_> = res.headers().get_all(header::SET_COOKIE).iter().collect();
        set_cookies.sort();
        assert_eq!(set_cookies, ["baz=3", "spam=4"]);
    }

    #[tokio::test]
    async fn remove_cookies() {
        let res = send("/remove", &["foo=1; bar=2"]).await;
        let mut set_cookies = res.headers().get_all(header::SET_COOKIE).iter();
        let first = set_cookies.next().unwrap().to_str().unwrap();
        assert!(first.starts_with("foo=; Max-Age=0"));
        assert_eq!(set_cookies.next(), None);
    }
}