// tide_http_auth/scheme.rs

use std::any::Any;

use http_types::{Result, StatusCode};

use crate::storage::Storage;

6#[async_trait::async_trait]
7pub trait Scheme<User: Send + Sync + 'static> {
8    type Request: Any + Send + Sync;
9
10    async fn authenticate<S>(&self, state: &S, auth_param: &str) -> Result<Option<User>>
11    where
12        S: Storage<User, Self::Request> + Send + Sync + 'static;
13
14    fn should_401_on_multiple_values() -> bool {
15        true
16    }
17    fn should_403_on_bad_auth() -> bool {
18        true
19    }
20
21    fn header_name() -> &'static str {
22        "Authorization"
23    }
24    fn scheme_name() -> &'static str;
25}
26
/// Stateless implementation of the HTTP `Basic` authentication scheme.
///
/// The auth parameter is base64 of `username:password`; the decoded pair is
/// looked up through the storage backend as a [`BasicAuthRequest`].
#[derive(Debug, Default)]
pub struct BasicAuthScheme;

/// Credentials decoded from a `Basic` auth parameter, passed to the storage
/// backend to resolve a user.
///
/// NOTE(review): the derived `Debug` prints `password` in clear text; avoid
/// logging values of this type.
#[derive(Debug)]
pub struct BasicAuthRequest {
    /// The decoded user-id.
    pub username: String,
    /// The decoded password.
    pub password: String,
}

36#[async_trait::async_trait]
37impl<User: Send + Sync + 'static> Scheme<User> for BasicAuthScheme {
38    type Request = BasicAuthRequest;
39
40    async fn authenticate<S>(&self, state: &S, auth_param: &str) -> Result<Option<User>>
41    where
42        S: Storage<User, Self::Request> + Send + Sync + 'static,
43    {
44        let bytes = base64::decode(auth_param);
45        if bytes.is_err() {
46            // This is invalid. Fail the request.
47            return Err(http_types::Error::from_str(
48                StatusCode::Unauthorized,
49                "Basic auth param must be valid base64.",
50            ));
51        }
52
53        let as_utf8 = String::from_utf8(bytes.unwrap());
54        if as_utf8.is_err() {
55            // You know the drill.
56            return Err(http_types::Error::from_str(
57                StatusCode::Unauthorized,
58                "Basic auth param base64 must contain valid utf-8.",
59            ));
60        }
61
62        let as_utf8 = as_utf8.unwrap();
63        let parts: Vec<_> = as_utf8.split(':').collect();
64
65        if parts.len() < 2 {
66            return Ok(None);
67        }
68
69        let (username, password) = (parts[0], parts[1]);
70
71        let user = state
72            .get_user(BasicAuthRequest {
73                username: username.to_owned(),
74                password: password.to_owned(),
75            })
76            .await?;
77
78        Ok(user)
79    }
80
81    fn scheme_name() -> &'static str {
82        "Basic "
83    }
84}
85
/// HTTP `Bearer` authentication scheme.
///
/// The auth parameter must start with `prefix`. The prefix is removed and the
/// remaining token is looked up through the storage backend as a
/// `BearerAuthRequest`.
#[derive(Default, Debug)]
pub struct BearerAuthScheme {
    /// Required leading text of the auth parameter. It is empty by default,
    /// which accepts any token.
    prefix: String,
}

impl BearerAuthScheme {
    /// Create a scheme that only accepts auth parameters starting with `prefix`.
    ///
    /// Without this constructor the private `prefix` field could only ever be
    /// empty (via `Default`).
    pub fn with_prefix(prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
        }
    }
}

/// Token taken from a `Bearer` auth parameter, passed to the storage backend
/// to resolve a user.
pub struct BearerAuthRequest {
    /// The token text, with the scheme's configured prefix already removed.
    pub token: String,
}

95#[async_trait::async_trait]
96impl<User: Send + Sync + 'static> Scheme<User> for BearerAuthScheme {
97    type Request = BearerAuthRequest;
98
99    async fn authenticate<S>(&self, state: &S, auth_param: &str) -> Result<Option<User>>
100    where
101        S: Storage<User, Self::Request> + Send + Sync + 'static,
102    {
103        if !auth_param.starts_with(self.prefix.as_str()) {
104            return Ok(None);
105        }
106
107        // TODO: validate that the auth_param (sans the prefix) is a valid uuid.
108        let user = state
109            .get_user(BearerAuthRequest {
110                token: (&auth_param[self.prefix.len()..]).to_owned(),
111            })
112            .await?;
113        Ok(user)
114    }
115
116    fn scheme_name() -> &'static str {
117        "Bearer "
118    }
119}