1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use std::sync::Arc;

use base64::prelude::*;
use hmac::{Hmac, Mac};
use rand::prelude::*;
use secstr::SecStr;
use sha2::Sha256;
use tower_cookies::{Cookie, Cookies};

use crate::{error::Error, surf::Config};

/// An extension providing a way to interact with a visitor's
/// CSRF token.
///
/// The token itself lives in a cookie; this handle only carries what is needed to read,
/// write, and remove that cookie.
#[derive(Clone)]
pub struct Token {
    /// Shared settings: the signing secret, the cookie name, and the attributes
    /// (expiry, `HttpOnly`, `SameSite`, `Secure`) applied to the token cookie.
    pub(crate) config: Arc<Config>,
    /// The visitor's cookies, from which the token cookie is read and to which it is written.
    pub(crate) cookies: Cookies,
}

impl Token {
    /// Issues a new token signed over a randomly generated identifier and stores it in the
    /// visitor's cookie.
    ///
    /// The identifier is a random `i128` rendered as a decimal string; the cookie is written
    /// exactly as by [`Token::set`].
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidLength`]
    pub(crate) fn create(&self) -> Result<(), Error> {
        let identifier: i128 = thread_rng().gen();
        // Delegate to `set` so the cookie attributes are defined in a single place.
        self.set(identifier.to_string())
    }

    /// Updates the identifier used to sign the token. The value should only be valid for the
    /// duration of the user's authenticated session and should be unique to that session.
    ///
    /// The signed token is stored in the configured cookie with path `/` and the expiry,
    /// `HttpOnly`, `SameSite` and `Secure` attributes taken from the configuration.
    ///
    /// See: [OWASP's CSRF Prevention Cheat Sheet](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#employing-hmac-csrf-tokens).
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidLength`]
    pub fn set(&self, identifier: impl Into<String>) -> Result<(), Error> {
        let token = create_token(&self.config.secret, identifier)?;

        let cookie = Cookie::build((self.config.cookie_name(), token))
            .path("/")
            .expires(self.config.expires)
            .http_only(self.config.http_only)
            .same_site(self.config.same_site)
            .secure(self.config.secure)
            .build();

        self.cookies.add(cookie);

        Ok(())
    }

    /// Get the current visitor's token.
    ///
    /// Returns the raw value of the token cookie, i.e. the full signed token string.
    ///
    /// # Errors
    ///
    /// - [`Error::NoCookie`] if the visitor has no token cookie.
    pub fn get(&self) -> Result<String, Error> {
        self.cookies
            .get(&self.config.cookie_name())
            .map(|cookie| cookie.value().to_owned())
            .ok_or(Error::NoCookie)
    }

    /// Reset the token to an identifier generated by [Surf](`crate::Surf`).
    ///
    /// This removes the visitor's token cookie, discarding any identifier passed to
    /// [`Token::set`]. NOTE(review): presumably the middleware then issues a new token with a
    /// random identifier (via `create`); confirm against the caller.
    pub fn reset(&self) {
        // The removal cookie must carry the same path as the cookie written by `set`.
        // Otherwise browsers treat it as a different cookie and keep the original token.
        let cookie = Cookie::build((self.config.cookie_name(), ""))
            .path("/")
            .build();

        self.cookies.remove(cookie);
    }
}

/// HMAC backed by SHA-256, keyed with the configured secret and used to sign token messages.
type HmacSha256 = Hmac<Sha256>;

/// Builds a signed CSRF token for `identifier`.
///
/// The token has the form `"{signature}.{identifier}!{nonce}"`. In it:
///
/// - `nonce` is the `BASE64_STANDARD` encoding of `get_random_value()`.
/// - `signature` is the base64-encoded HMAC-SHA256 of everything after the first `.`, keyed
///   with `secret`.
///
/// # Errors
///
/// Propagates any error returned by `sign_and_encode`.
pub(crate) fn create_token(
    secret: &SecStr,
    identifier: impl Into<String>,
) -> Result<String, Error> {
    let identifier: String = identifier.into();
    let nonce = BASE64_STANDARD.encode(get_random_value());
    let message = [identifier.as_str(), nonce.as_str()].join("!");

    sign_and_encode(secret, &message).map(|signature| format!("{}.{}", signature, message))
}

pub(crate) fn validate_token(secret: &SecStr, cookie: &str, token: &str) -> Result<bool, Error> {
    let mut parts = token.splitn(2, '.');
    let received_hmac = parts.next().unwrap_or("");

    let message = parts.next().unwrap_or("");
    let expected_hmac = sign_and_encode(secret, message)?;

    Ok(received_hmac == expected_hmac && cookie == token)
}

/// Returns 64 bytes from the thread-local RNG; `create_token` base64-encodes them into each
/// token's message.
#[cfg(not(test))]
fn get_random_value() -> [u8; 64] {
    let mut rng = thread_rng();
    let mut bytes: [u8; 64] = [0; 64];
    rng.fill(&mut bytes);
    bytes
}

/// Test-only stand-in that returns a fixed value, so tokens built in tests are deterministic.
#[cfg(test)]
fn get_random_value() -> [u8; 64] {
    [0x2A; 64]
}

/// Computes the HMAC-SHA256 of `message` keyed with `secret` and returns the tag encoded with
/// `BASE64_STANDARD`.
///
/// # Errors
///
/// Propagates the error from `new_from_slice`, converted into [`Error`].
fn sign_and_encode(secret: &SecStr, message: &str) -> Result<String, Error> {
    let key = secret.unsecure();
    let mut hmac = HmacSha256::new_from_slice(key)?;
    hmac.update(message.as_bytes());

    let tag = hmac.finalize().into_bytes();
    Ok(BASE64_STANDARD.encode(tag))
}

#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// The token layout is `"{signature}.{identifier}!{nonce}"`, with the deterministic
    /// test nonce.
    #[test]
    fn create_token() -> Result<()> {
        let secret = SecStr::from("super-secret");
        let token = super::create_token(&secret, "identifier")?;

        let parts = token.splitn(2, '.').collect::<Vec<&str>>();
        assert_eq!(parts.len(), 2);

        let message = format!("{}!{}", "identifier", BASE64_STANDARD.encode([42u8; 64]));
        assert_eq!(parts[1], message);

        let signature = sign_and_encode(&secret, &message)?;
        assert_eq!(parts[0], signature);

        Ok(())
    }

    #[test]
    fn validate_token_accepts_matching_cookie() -> Result<()> {
        let secret = SecStr::from("super-secret");
        let token = super::create_token(&secret, "identifier")?;

        assert!(validate_token(&secret, &token, &token)?);

        Ok(())
    }

    #[test]
    fn validate_token_rejects_cookie_mismatch() -> Result<()> {
        let secret = SecStr::from("super-secret");
        let token = super::create_token(&secret, "identifier")?;
        let other = super::create_token(&secret, "other")?;

        // The token is correctly signed, so only the cookie comparison can reject it.
        assert!(!validate_token(&secret, &other, &token)?);

        Ok(())
    }

    #[test]
    fn validate_token_rejects_forged_signature() -> Result<()> {
        let secret = SecStr::from("super-secret");
        let forged = super::create_token(&SecStr::from("wrong-secret"), "identifier")?;

        // The cookie matches the submitted token, so only the signature check can reject it.
        assert!(!validate_token(&secret, &forged, &forged)?);

        Ok(())
    }

    #[test]
    fn validate_token_rejects_malformed_token() -> Result<()> {
        let secret = SecStr::from("super-secret");

        for token in ["", "no-separator", "!!!not-base64.identifier!nonce"] {
            assert!(!validate_token(&secret, token, token)?);
        }

        Ok(())
    }
}