trillium_basic_auth/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11/*!
12Basic authentication for trillium.rs
13
14```rust,no_run
15use trillium_basic_auth::BasicAuth;
16trillium_smol::run((
17    BasicAuth::new("trillium", "7r1ll1um").with_realm("rust"),
18    |conn: trillium::Conn| async move { conn.ok("authenticated") },
19));
20```
21*/
22use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
23use trillium::{
24    async_trait, Conn, Handler,
25    KnownHeaderName::{Authorization, WwwAuthenticate},
26    Status,
27};
28
29/// basic auth handler
30#[derive(Clone, Debug)]
31pub struct BasicAuth {
32    credentials: Credentials,
33    realm: Option<String>,
34
35    // precomputed/derived data fields:
36    expected_header: String,
37    www_authenticate: String,
38}
39
/// Basic auth username-password credentials.
///
/// The `Debug` representation redacts the password so that formatting these
/// credentials (e.g. in logs) does not disclose it.
#[derive(Clone, PartialEq, Eq)]
pub struct Credentials {
    username: String,
    password: String,
}

// Hand-written rather than derived so the password never appears in `{:?}`
// output; the username is still shown to keep the value identifiable.
impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("username", &self.username)
            .field("password", &"[redacted]")
            .finish()
    }
}
46
47impl Credentials {
48    fn new(username: &str, password: &str) -> Self {
49        Self {
50            username: String::from(username),
51            password: String::from(password),
52        }
53    }
54
55    fn expected_header(&self) -> String {
56        format!(
57            "Basic {}",
58            BASE64.encode(format!("{}:{}", self.username, self.password))
59        )
60    }
61
62    // const BASIC: &str = "Basic ";
63    // pub fn for_conn(conn: &Conn) -> Option<Self> {
64    //     conn.request_headers()
65    //         .get_str(KnownHeaderName::Authorization)
66    //         .and_then(|value| {
67    //             if value[..BASIC.len().min(value.len())].eq_ignore_ascii_case(BASIC) {
68    //                 Some(&value[BASIC.len()..])
69    //             } else {
70    //                 None
71    //             }
72    //         })
73    //         .and_then(|base64_credentials| BASE64.decode(base64_credentials).ok())
74    //         .and_then(|credential_bytes| String::from_utf8(credential_bytes).ok())
75    //         .and_then(|mut credential_string| {
76    //             credential_string.find(":").map(|colon| {
77    //                 let password = credential_string.split_off(colon + 1).into();
78    //                 credential_string.pop();
79    //                 Self {
80    //                     username: credential_string.into(),
81    //                     password,
82    //                 }
83    //             })
84    //         })
85    // }
86}
87
88impl BasicAuth {
89    /// build a new basic auth handler with the provided username and password
90    pub fn new(username: &str, password: &str) -> Self {
91        let credentials = Credentials::new(username, password);
92        let expected_header = credentials.expected_header();
93        let realm = None;
94        Self {
95            expected_header,
96            credentials,
97            realm,
98            www_authenticate: String::from("Basic"),
99        }
100    }
101
102    /// provide a realm for the www-authenticate response sent by this handler
103    pub fn with_realm(mut self, realm: &str) -> Self {
104        self.www_authenticate = format!("Basic realm=\"{}\"", realm.replace('\"', "\\\""));
105        self.realm = Some(String::from(realm));
106        self
107    }
108
109    fn is_allowed(&self, conn: &Conn) -> bool {
110        conn.request_headers().get_str(Authorization) == Some(&*self.expected_header)
111    }
112
113    fn deny(&self, conn: Conn) -> Conn {
114        conn.with_status(Status::Unauthorized)
115            .with_response_header(WwwAuthenticate, self.www_authenticate.clone())
116            .halt()
117    }
118}
119
120#[async_trait]
121impl Handler for BasicAuth {
122    async fn run(&self, conn: Conn) -> Conn {
123        if self.is_allowed(&conn) {
124            conn.with_state(self.credentials.clone())
125        } else {
126            self.deny(conn)
127        }
128    }
129}