rwf/controller/middleware/
csrf.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//! CSRF (cross-site request forgery) protection.
//!
//! CSRF is a common attack that tricks your users' browsers into sending
//! requests that change their data managed by your application.
//! CSRF protection ensures that form submissions and other state-changing
//! requests (e.g. POST) sent to the web app originate from pages generated by the same website.
//!
//! ### Usage
//! CSRF protection is enabled by default. To make it work, include a Rwf-generated token
//! in all forms submitted via POST:
//!
//! ```html
//! <form method="post">
//!     <%= csrf_token() %>
//! </form>
//! ```
//!
//! For AJAX requests, send the CSRF token in the `X-CSRF-Token` request header.
//! You can obtain the raw token by calling the `csrf_token_raw` template function:
//!
//! ```html
//! <script>
//!     window.csrf_token = "<%= csrf_token_raw() %>";
//! </script>
//! ```
//!
//! ### Configuration
//! Toggle `csrf_protection` in the configuration to enable/disable CSRF protection application-wide, e.g.:
//!
//! ```toml
//! [general]
//! csrf_protection = false
//! ```
use super::prelude::*;
use crate::{crypto::csrf_token_validate, http::Method};

/// Name of the form field that carries the CSRF token in submitted HTML forms.
pub static CSRF_INPUT: &str = "rwf_csrf_token";
/// Name of the HTTP request header that carries the CSRF token (used by AJAX clients).
pub static CSRF_HEADER: &str = "X-CSRF-Token";

/// CSRF protection middleware.
///
/// A unit struct holding no configuration; all checks are performed per request
/// by its `Middleware` implementation.
#[derive(Debug, Default, Clone, Copy)]
pub struct Csrf;

impl Csrf {
    /// Create CSRF protection middleware.
    ///
    /// Equivalent to `Csrf::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

#[async_trait]
impl Middleware for Csrf {
    async fn handle_request(&self, request: Request) -> Result<Outcome, Error> {
        if request.skip_csrf() {
            return Ok(Outcome::Forward(request));
        }

        if ![Method::Put, Method::Post, Method::Patch].contains(request.method()) {
            return Ok(Outcome::Forward(request));
        }

        let header = request.header(CSRF_HEADER);

        if let Some(header) = header {
            if csrf_token_validate(header) {
                return Ok(Outcome::Forward(request));
            }
        }

        match request.form_data() {
            Ok(form_data) => {
                if let Some(token) = form_data.get::<String>(CSRF_INPUT) {
                    if csrf_token_validate(&token) {
                        return Ok(Outcome::Forward(request));
                    }
                }
            }

            Err(_) => (),
        }

        Ok(Outcome::Stop(request, Response::csrf_error()))
    }
}