tower_surf/
surf.rs

1use futures_util::future::BoxFuture;
2use http::{HeaderValue, Request, Response};
3use secstr::SecStr;
4use std::{
5    sync::Arc,
6    task::{Context, Poll},
7};
8use tower_cookies::{
9    cookie::{Expiration, SameSite},
10    CookieManager, Cookies,
11};
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::{guard::GuardService, Error, Token};
16
17#[derive(Clone)]
18pub(crate) struct Config {
19    pub(crate) secret: SecStr,
20    pub(crate) cookie_name: String,
21    pub(crate) expires: Expiration,
22    pub(crate) header_name: String,
23    pub(crate) hsts: bool,
24    pub(crate) http_only: bool,
25    pub(crate) prefix: bool,
26    pub(crate) preload: bool,
27    pub(crate) same_site: SameSite,
28    pub(crate) secure: bool,
29}
30
31impl Config {
32    pub(crate) fn cookie_name(&self) -> String {
33        if self.prefix {
34            format!("__HOST-{}", self.cookie_name)
35        } else {
36            self.cookie_name.clone()
37        }
38    }
39}
40
/// A layer providing the [`Token`] extension.
///
/// On every request, a cookie is created for the current visitor unless the
/// request already carries one under the configured cookie name. The session
/// ID used is an `i128` generated by the [`rand`] crate.
///
/// Construct it with [`Surf::new`] and adjust the defaults through the
/// builder-style methods, each of which consumes and returns the layer.
#[derive(Clone)]
pub struct Surf {
    /// Settings collected by the builder methods; cloned into each service
    /// the layer produces.
    pub(crate) config: Config,
}
50
51impl Surf {
52    /// Creates a new [`Surf`] layer with the provided secret and default token configuration.
53    pub fn new(secret: impl Into<String>) -> Self {
54        Self {
55            config: Config {
56                secret: SecStr::from(secret.into()),
57                cookie_name: "csrf_token".into(),
58                expires: Expiration::Session,
59                header_name: "X-CSRF-Token".into(),
60                hsts: true,
61                http_only: true,
62                prefix: true,
63                preload: false,
64                same_site: SameSite::Strict,
65                secure: true,
66            },
67        }
68    }
69
70    /// Sets the cookie name. Note that this will be previed with `__HOST-` unless
71    /// you have disabled it with [prefix](`Surf::prefix`). The default value is `csrf_token`.
72    pub fn cookie_name(mut self, cookie_name: impl Into<String>) -> Self {
73        self.config.cookie_name = cookie_name.into();
74
75        self
76    }
77
78    /// Sets the cookie's expiration. The default value is `Expiration::Session`.
79    pub fn expires(mut self, expires: Expiration) -> Self {
80        self.config.expires = expires;
81
82        self
83    }
84
85    /// Sets the header name used when validating the request. The default
86    /// value is `X-CSRF-Token`.
87    pub fn header_name(mut self, header_name: impl Into<String>) -> Self {
88        self.config.header_name = header_name.into();
89
90        self
91    }
92
93    /// Sets whether to send the `Strict-Transport-Security` header.
94    ///
95    /// See: [HTTP Strict Transport Security Cheat Sheet](https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Strict_Transport_Security_Cheat_Sheet.html)
96    pub fn hsts(mut self, hsts: bool) -> Self {
97        self.config.hsts = hsts;
98
99        self
100    }
101
102    /// Sets the `HTTPOnly` attribute of the cookie. The default value is `true`.
103    ///
104    /// ⚠️ **Warning**: This should generally _not_ be set to false.
105    /// See: [HttpOnly Cookie Attribute](https://owasp.org/www-community/HttpOnly).
106    pub fn http_only(mut self, http_only: bool) -> Self {
107        self.config.http_only = http_only;
108
109        self
110    }
111
112    /// Sets whether to prefix the cookie name with `__HOST-`. The default
113    /// value is `true`.
114    ///
115    /// See: [Cookie Name](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#cookie-namecookie-value).
116    pub fn prefix(mut self, prefix: bool) -> Self {
117        self.config.prefix = prefix;
118
119        self
120    }
121
122    /// Sets whether to append the [hsts](`Surf::hsts`) header with `preload`.
123    ///
124    /// See: [HTTP Strict Transport Security Cheat Sheet](https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Strict_Transport_Security_Cheat_Sheet.html)
125    pub fn preload(mut self, preload: bool) -> Self {
126        self.config.preload = preload;
127
128        self
129    }
130
131    /// Sets the `SameSite` attribute of the cookie. The default value is [`SameSite::Strict`].
132    ///
133    /// See: [SameSite Cookie Attribute](https://owasp.org/www-community/SameSite).
134    pub fn same_site(mut self, same_site: SameSite) -> Self {
135        self.config.same_site = same_site;
136
137        self
138    }
139
140    /// Sets the `secure` attribute of the cookie. Note that this is required to
141    /// be `false` for cookies to work on `localhost`. The default value is `true`.
142    ///
143    /// See: [Secure Cookie Attribute](https://owasp.org/www-community/controls/SecureCookieAttribute).
144    pub fn secure(mut self, secure: bool) -> Self {
145        self.config.secure = secure;
146
147        self
148    }
149}
150
151impl<S> Layer<S> for Surf {
152    type Service = CookieManager<SurfService<GuardService<S>>>;
153
154    fn layer(&self, inner: S) -> Self::Service {
155        CookieManager::new(SurfService {
156            config: Arc::new(self.config.clone()),
157            inner: GuardService::new(inner),
158        })
159    }
160}
161
/// Service produced by the [`Surf`] layer.
///
/// For each request it makes sure the cookie exists, exposes the
/// configuration and a [`Token`] through the request extensions, and
/// optionally adds the `Strict-Transport-Security` header to the response.
#[derive(Clone)]
pub struct SurfService<S> {
    /// Configuration snapshot taken when the layer was applied.
    config: Arc<Config>,
    /// The wrapped service that receives the request.
    inner: S,
}
167
168impl<S, Q, R> Service<Request<Q>> for SurfService<S>
169where
170    S: Service<Request<Q>, Response = Response<R>> + Send + 'static,
171    S::Future: Send + 'static,
172    Q: Send + 'static,
173    R: Default + Send,
174{
175    type Response = S::Response;
176    type Error = S::Error;
177    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
178
179    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
180        self.inner.poll_ready(cx)
181    }
182
183    fn call(&mut self, mut request: Request<Q>) -> Self::Future {
184        let cookies = match request
185            .extensions()
186            .get::<Cookies>()
187            .ok_or(Error::ExtensionNotFound("Cookies".into()))
188        {
189            Ok(cookies) => cookies,
190            Err(err) => return Box::pin(async move { Error::make_layer_error(err) }),
191        };
192
193        let token = Token {
194            config: self.config.clone(),
195            cookies: cookies.clone(),
196        };
197
198        if cookies.get(&self.config.cookie_name()).is_none() {
199            if let Err(err) = token.create() {
200                return Box::pin(async move { Error::make_layer_error(err) });
201            };
202        }
203
204        request.extensions_mut().insert(self.config.clone());
205        request.extensions_mut().insert(token);
206
207        let config = self.config.clone();
208
209        if config.hsts {
210            let future = self.inner.call(request);
211
212            Box::pin(async move {
213                let mut response = future.await?;
214
215                let mut value = "max-age=31536000; includeSubDomains".to_owned();
216
217                if config.preload {
218                    value.push_str("; preload");
219                }
220
221                let value = match HeaderValue::from_str(&value) {
222                    Ok(value) => value,
223                    Err(err) => return Error::make_layer_error(err),
224                };
225
226                response
227                    .headers_mut()
228                    .insert("Strict-Transport-Security", value);
229
230                Ok(response)
231            })
232        } else {
233            Box::pin(self.inner.call(request))
234        }
235    }
236}