1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
use futures_util::future::BoxFuture;
use http::{Request, Response};
use secstr::SecStr;
use std::{
    sync::Arc,
    task::{Context, Poll},
};
use tower_cookies::{
    cookie::{Expiration, SameSite},
    CookieManager, Cookies,
};
use tower_layer::Layer;
use tower_service::Service;

use crate::{guard::GuardService, Error, Token};

/// Settings for the CSRF token cookie.
///
/// Built by [`Surf`] and shared with every request as an `Arc<Config>`
/// extension.
#[derive(Clone)]
pub(crate) struct Config {
    /// Secret material, held in a `SecStr`.
    /// NOTE(review): presumably consumed by [`Token`]; confirm there.
    pub(crate) secret: SecStr,
    /// Base cookie name, before any prefix is applied (see [`Config::cookie_name`]).
    pub(crate) cookie_name: String,
    /// Expiration attribute used for the token cookie.
    pub(crate) expires: Expiration,
    /// Name of the request header consulted when validating a request.
    pub(crate) header_name: String,
    /// Whether the token cookie is marked `HttpOnly`.
    pub(crate) http_only: bool,
    /// Whether [`Config::cookie_name`] prepends the host-only cookie prefix.
    pub(crate) prefix: bool,
    /// `SameSite` attribute used for the token cookie.
    pub(crate) same_site: SameSite,
    /// Whether the token cookie is marked `Secure`.
    pub(crate) secure: bool,
}

impl Config {
    /// Returns the name under which the token cookie is stored.
    ///
    /// When `prefix` is enabled, the configured name is prefixed with
    /// `__Host-`. The prefix is spelled exactly as the cookie specification
    /// defines it. Browsers that match cookie prefixes case-sensitively do not
    /// recognise other spellings (such as `__HOST-`), and would then silently
    /// skip the host-only restrictions the prefix is meant to enforce.
    pub(crate) fn cookie_name(&self) -> String {
        if self.prefix {
            format!("__Host-{}", self.cookie_name)
        } else {
            self.cookie_name.clone()
        }
    }
}

/// A layer providing the [`Token`] extension.
///
/// On every request, it will create a cookie for the current visitor if
/// one has not already been created. The session ID used is an `i128`
/// generated by the [`rand`] crate.
#[derive(Clone)]
pub struct Surf {
    pub(crate) config: Config,
}

impl Surf {
    /// Creates a new [`Surf`] layer with the provided secret and default token configuration.
    ///
    /// Defaults:
    /// - cookie name `csrf_token`, prefixed with `__Host-`
    /// - session expiration
    /// - header `X-CSRF-Token`
    /// - `HttpOnly`, `Secure` and `SameSite=Strict`
    #[must_use]
    pub fn new(secret: impl Into<String>) -> Self {
        Self {
            config: Config {
                secret: SecStr::from(secret.into()),
                cookie_name: "csrf_token".into(),
                expires: Expiration::Session,
                header_name: "X-CSRF-Token".into(),
                http_only: true,
                prefix: true,
                same_site: SameSite::Strict,
                secure: true,
            },
        }
    }

    /// Sets the cookie name.
    ///
    /// The name will be prefixed with `__Host-` unless you have disabled this
    /// with [prefix](`Surf::prefix`). The default value is `csrf_token`.
    #[must_use]
    pub fn cookie_name(mut self, cookie_name: impl Into<String>) -> Self {
        self.config.cookie_name = cookie_name.into();

        self
    }

    /// Sets the cookie's expiration. The default value is `Expiration::Session`.
    #[must_use]
    pub fn expires(mut self, expires: Expiration) -> Self {
        self.config.expires = expires;

        self
    }

    /// Sets the header name used when validating the request. The default
    /// value is `X-CSRF-Token`.
    #[must_use]
    pub fn header_name(mut self, header_name: impl Into<String>) -> Self {
        self.config.header_name = header_name.into();

        self
    }

    /// Sets the `HTTPOnly` attribute of the cookie. The default value is `true`.
    ///
    /// ⚠️ **Warning**: This should generally _not_ be set to false.
    /// See: [HttpOnly Cookie Attribute](https://owasp.org/www-community/HttpOnly).
    #[must_use]
    pub fn http_only(mut self, http_only: bool) -> Self {
        self.config.http_only = http_only;

        self
    }

    /// Sets whether to prefix the cookie name with `__Host-`. The default
    /// value is `true`.
    ///
    /// Browsers only accept a `__Host-` cookie that is `Secure`, has no
    /// `Domain` attribute and uses `Path=/`. Disable the prefix if those
    /// conditions cannot be met, for example when [secure](`Surf::secure`) is
    /// `false`.
    ///
    /// See: [Cookie Name](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#cookie-namecookie-value).
    #[must_use]
    pub fn prefix(mut self, prefix: bool) -> Self {
        self.config.prefix = prefix;

        self
    }

    /// Sets the `SameSite` attribute of the cookie. The default value is [`SameSite::Strict`].
    ///
    /// See: [SameSite Cookie Attribute](https://owasp.org/www-community/SameSite).
    #[must_use]
    pub fn same_site(mut self, same_site: SameSite) -> Self {
        self.config.same_site = same_site;

        self
    }

    /// Sets the `secure` attribute of the cookie. The default value is `true`.
    ///
    /// This may need to be `false` for cookies to work on `localhost` over plain
    /// HTTP. If you set it to `false`, also disable [prefix](`Surf::prefix`),
    /// because browsers reject `__Host-` cookies that lack the `Secure` attribute.
    ///
    /// See: [Secure Cookie Attribute](https://owasp.org/www-community/controls/SecureCookieAttribute).
    #[must_use]
    pub fn secure(mut self, secure: bool) -> Self {
        self.config.secure = secure;

        self
    }
}

impl<S> Layer<S> for Surf {
    type Service = CookieManager<SurfService<GuardService<S>>>;

    /// Builds the service stack, outermost first:
    /// [`CookieManager`] → [`SurfService`] → [`GuardService`] → `inner`.
    ///
    /// The cookie manager is outermost so that the [`Cookies`] extension is
    /// already on the request when [`SurfService`] looks it up.
    fn layer(&self, inner: S) -> Self::Service {
        let shared_config = Arc::new(self.config.clone());
        let guarded = GuardService::new(inner);
        let surf_service = SurfService {
            config: shared_config,
            inner: guarded,
        };
        CookieManager::new(surf_service)
    }
}

/// The service produced by the [`Surf`] layer.
///
/// It attaches a [`Token`] to every request before forwarding the request to
/// the wrapped service.
#[derive(Clone)]
pub struct SurfService<S> {
    /// Configuration shared, via `Arc`, with every [`Token`] this service builds.
    config: Arc<Config>,
    /// The wrapped service that receives the request afterwards.
    inner: S,
}

impl<S, Q, R> Service<Request<Q>> for SurfService<S>
where
    S: Service<Request<Q>, Response = Response<R>> + Send + 'static,
    S::Future: Send + 'static,
    Q: Send + 'static,
    R: Default + Send,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut request: Request<Q>) -> Self::Future {
        let cookies = match request
            .extensions()
            .get::<Cookies>()
            .ok_or(Error::ExtensionNotFound("Cookies".into()))
        {
            Ok(cookies) => cookies,
            Err(err) => return Box::pin(async move { Error::make_layer_error(err) }),
        };

        let token = Token {
            config: self.config.clone(),
            cookies: cookies.clone(),
        };

        if cookies.get(&self.config.cookie_name()).is_none() {
            if let Err(err) = token.create() {
                return Box::pin(async move { Error::make_layer_error(err) });
            };
        }

        request.extensions_mut().insert(self.config.clone());
        request.extensions_mut().insert(token);

        Box::pin(self.inner.call(request))
    }
}